当前位置:首页 > 编程笔记 > 正文
已解决

SCAR的pytorch实现

来自网友在路上 172872提问 提问时间:2023-11-10 13:19:12阅读次数: 72

最佳答案 问答题库728位专家为你答疑解惑

本文所实现的网络来源于 SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing 2019)

import torch;from torchvision import models
from torchvision.models import vgg16
import warnings;from torch import nn
# Suppress every warning for the whole interpreter (no category/message filter).
# NOTE(review): this also hides deprecation warnings; consider narrowing it.
warnings.filterwarnings("ignore")
# NOTE(review): constructs torchvision's VGG16 with pretrained=True at import
# time and rebinds the name `vgg16` from the factory function to that model
# instance. Nothing in the visible code uses this object (SCAR calls
# models.vgg16 itself) — confirm no other caller needs it before removing.
vgg16 = vgg16(pretrained=True)
def initialize_weights(models):
    """Run ``real_init_weights`` on each element yielded by ``models``.

    ``models`` is any iterable (e.g. ``module.modules()``,
    ``module.children()`` or a plain list). Its elements are processed in
    iteration order. Inside this function the parameter name hides the
    module-level ``torchvision.models`` import.
    """
    for element in models:
        real_init_weights(element)
# NOTE(review): repeats the `import warnings` and filterwarnings("ignore")
# already executed a few lines earlier; appears redundant.
import warnings
warnings.filterwarnings("ignore")
def real_init_weights(m):
    """Recursively (re)initialize the parameters of ``m`` in place.

    ``m`` may be a module, or a list/tuple of modules (handled recursively).

    * ``nn.Conv2d``: weight ~ N(0, 0.01^2); bias (if present) set to 0.
    * ``nn.Linear``: weight ~ N(0, 0.01^2); bias left unchanged.
    * ``nn.BatchNorm1d/2d/3d``: weight set to 1, bias to 0. Layers built
      with ``affine=False`` have no such parameters and are skipped.
    * any other ``nn.Module``: recurse into its direct children.

    Anything that is neither a module nor a list/tuple is printed and
    otherwise ignored.
    """
    if isinstance(m, (list, tuple)):
        for mini_m in m:
            real_init_weights(mini_m)
    elif isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # weight/bias are None when the layer was built with affine=False;
        # nn.init.constant_ would fail on None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Module):
        # Containers (Sequential, custom modules) and parameter-free leaves
        # such as ReLU/MaxPool2d end up here; the latter have no children.
        for child in m.children():
            real_init_weights(child)
    else:
        print(m)
# VGG frontend: the first 23 modules of torchvision's VGG16 ``features``
# (10 conv layers with ReLUs and the first three 2x2 max-pools). Module
# order/indices mirror ``vgg16().features[0:23]`` so SCAR can load that
# state_dict directly. SCAR deep-copies this module per instance.
vgg10 = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(128, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    # Third (last) pooling. The author noted experimenting with removing it
    # so that no final upsampling would be needed.
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(256, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    # VGG16's fourth max-pool is intentionally not included.
)


class SAM(torch.nn.Module):
    """Spatial-wise attention module; the output has the same shape as ``x``.

    1x1 convolutions project ``x`` into query/key/value maps. Every spatial
    position attends over all H*W positions. The attended features pass
    through the 1x1 ``lamda`` convolution and are added back to ``x``
    (residual connection).

    Args:
        channels: number of input/output channels (default 64, matching
            the last dilated convolution of SCAR).
    """

    def __init__(self, channels=64):
        super().__init__()
        self.q = torch.nn.Conv2d(channels, channels, 1)
        self.k = torch.nn.Conv2d(channels, channels, 1)
        self.v = torch.nn.Conv2d(channels, channels, 1)
        self.lamda = torch.nn.Conv2d(channels, channels, 1)
        # Not used in forward(); kept so existing state_dict keys still load.
        self.bn = torch.nn.BatchNorm2d(channels)

    def forward(self, x):
        N, C, H, W = x.size()
        query = self.q(x).reshape(N, -1, H * W).permute(0, 2, 1)  # (N, HW, C)
        key = self.k(x).reshape(N, -1, H * W)  # (N, C, HW), own projection
        value = self.v(x).reshape(N, -1, H * W)  # (N, C, HW)
        # attention[n, i, j] = softmax_j(query_i . key_j): how much position i
        # draws from position j; rows (last dim) sum to 1.
        attention = torch.nn.functional.softmax(torch.bmm(query, key), dim=-1)
        # Contract over the normalized index j:
        # out[:, :, i] = sum_j value[:, :, j] * attention[i, j]
        out = torch.bmm(value, attention.permute(0, 2, 1)).reshape(N, C, H, W)
        return self.lamda(out) + x


class CAM(torch.nn.Module):
    """Channel-wise attention module; the output has the same shape as ``x``.

    A C x C channel-affinity matrix is built from the ``conv1`` projection,
    used to mix the channels of ``x``, passed through ``conv2`` and added
    back to ``x`` (residual connection).

    Args:
        channels: number of input/output channels (default 64).
    """

    def __init__(self, channels=64):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, 1)
        self.conv2 = torch.nn.Conv2d(channels, channels, 1)
        # Not used in forward(); kept so existing state_dict keys still load.
        self.bn = torch.nn.BatchNorm2d(channels)

    def forward(self, x):
        N, C, H, W = x.size()
        # Query and key share the conv1 projection, so compute it once.
        proj = self.conv1(x).reshape(N, C, -1)  # (N, C, HW)
        # A real transpose (permute), not a reshape, to get (N, HW, C).
        energy = torch.bmm(proj, proj.permute(0, 2, 1))  # (N, C, C)
        attention = torch.nn.functional.softmax(energy, dim=-1)
        value = x.reshape(N, C, -1)  # (N, C, HW)
        mixed = torch.bmm(attention, value).reshape(N, C, H, W)
        return self.conv2(mixed) + x


class SCAR(torch.nn.Module):
    """SCAR crowd-counting network (VGG frontend, dilated backend, SAM + CAM).

    Pipeline in ``forward``:

    1. The ``vgg10`` frontend applies three 2x2 max-pools and outputs 512
       channels.
    2. Six 3x3 dilated convolutions (512 -> 512 -> 512 -> 512 -> 256 -> 128
       -> 64) run, each followed by ReLU. Padding 2 with dilation 2 keeps
       H and W.
    3. SAM and CAM run in parallel on the 64-channel map.
    4. Their outputs are concatenated (CAM first) into 128 channels.
    5. A 1x1 conv reduces this to one output channel (presumably the
       density map).
    6. The result is upsampled by 8 (``F.interpolate``, default nearest
       mode).

    Args:
        loadwieght: name kept as-is for compatibility.

            * False (default): every layer is initialized with
              ``initialize_weights``. The frontend is then overwritten with
              torchvision's pretrained VGG16 ``features[0:23]`` weights.
            * True: neither step runs and layers keep PyTorch's default
              init — presumably the caller loads a checkpoint afterwards
              (confirm).
    """

    def __init__(self, loadwieght=False):
        super().__init__()
        import copy

        # Private copy: without it every SCAR instance would share (and
        # co-train) the single module-level `vgg10` frontend.
        self.vgg10 = copy.deepcopy(vgg10)
        self.dconv1 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv2 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv3 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv4 = torch.nn.Conv2d(512, 256, 3, dilation=2, stride=1, padding=2)
        self.dconv5 = torch.nn.Conv2d(256, 128, 3, dilation=2, stride=1, padding=2)
        self.dconv6 = torch.nn.Conv2d(128, 64, 3, dilation=2, stride=1, padding=2)
        self.relu = torch.nn.functional.relu
        self.SAM = SAM()
        self.CAM = CAM()
        self.finalconv = torch.nn.Conv2d(128, 1, 1)
        # F.upsample is deprecated; F.interpolate has the same defaults.
        self.upsample = torch.nn.functional.interpolate
        if not loadwieght:
            # Initialize only after every layer exists, so the backend and
            # attention heads are covered. Then overwrite the frontend with
            # ImageNet-pretrained VGG16 weights.
            initialize_weights(self.children())
            pretrained = models.vgg16(pretrained=True)
            self.vgg10.load_state_dict(pretrained.features[0:23].state_dict())

    def forward(self, x):
        y = self.vgg10(x)
        for dconv in (self.dconv1, self.dconv2, self.dconv3,
                      self.dconv4, self.dconv5, self.dconv6):
            y = self.relu(dconv(y))
        y_sa = self.SAM(y)
        y_ca = self.CAM(y)
        y = self.finalconv(torch.cat((y_ca, y_sa), dim=1))
        # Three 2x2 max-pools shrink H and W by 8, so upsample by 8 to get
        # back to the input resolution (exactly, when H and W divide by 8).
        return self.upsample(y, scale_factor=8)
查看全文

99%的人还看了

相似问题

猜你感兴趣

版权申明

本文"SCAR的pytorch实现":http://eshow365.cn/6-37253-0.html 内容来自互联网,请自行判断内容的正确性。如有侵权请联系我们,立即删除!