网络中加入注意力机制SE模块
SENet是由自动驾驶公司Momenta在2017年公布的一种全新的图像识别结构,它通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率。SENet 是2017 ILSVR竞赛的冠军。
论文:
SE block的基本结构
- 给定一个输入 ,其特征通道数为C ,通过一系列卷积等一般变换后得到一个特征通道数为C的特征。
- Squeeze:顺着空间维度进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。
- Excitation:基于特征通道间的相关性,每个特征通道生成一个权重,用来代表特征通道的重要程度。
- Reweight:将Excitation输出的权重看做每个特征通道的重要性,然后通过乘法逐通道加权到之前的特征上,完成在通道维度上的对原始特征的重标定。
代码:
import torch import torch.nn as nn import math from torchvision import models class se_block(nn.Module): def __init__(self, channel, ratio=16): super(se_block, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // ratio, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // ratio, channel, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y class Mobilenet_v2(nn.Module): def __init__(self): super(Mobilenet_v2, self).__init__() model = models.mobilenet_v2(pretrained=True) # Remove linear and pool layers (since were not doing classification) modules = list(model.children())[:-1] self.resnet = nn.Sequential(*modules) self.pool = nn.AvgPool2d(kernel_size=7) self.fc = nn.Linear(1280, 16) self.sigmoid = nn.Sigmoid() self.softmax = nn.Softmax(dim=-1) self.attention = se_block(1280) # 1280 为上层输出通道 def forward(self, images): x = self.resnet(images) # [N, 1280, 1, 1] x=self.attention(x) # 此处加入se—block x = self.pool(x) x = x.view(-1, 1280) # [N, 1280] x = self.fc(x) return x if __name__=="__main__": input = torch.rand(2, 3, 224, 224) mode = Mobilenet_v2() out = mode(input) print(out.size())
小结:
1、SE网络可以通过堆叠SE模块得到。
2、SE模块也可以嵌入到现在几乎所有的网络结构中。