快捷搜索: 王者荣耀 脱发

Pytorch框架 || torch.nn.modules.Module(nn.Module)

1 一个简单的网络

  1. 一个Pytorch模型应该以类的形式出现
  2. Pytorch训练模型应该是nn.Module的子类
  3. 一个训练模型最少包含init和forward(初始化和前向传播)两个过程。
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

2 nn.Module.init_weight()

    这个代码是SeNet的代码,放在这里学习init_weight
import numpy as np
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局均值池化  输出的是c×1×1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # channel // reduction代表通道压缩
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),  # 还原
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            print(m)  # 没运行到这儿
            if isinstance(m, nn.Conv2d):  # 判断类型函数——:m是nn.Conv2d类吗?
                init.kaiming_normal_(m.weight, mode=fan_out)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()  # 50×512×7×7
        y = self.avg_pool(x).view(b, c)  # ① maxpool之后得:50×512×1×1 ② view形状得到50×512
        y = self.fc(y).view(b, c, 1, 1)  # 50×512×1×1
        return x * y.expand_as(x)  # 根据x.size来扩展y


if __name__ == __main__:
    input = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)  # 实例化模型se
    output = se(input)
    print(output.shape)

2.1 kaiming 高斯初始化

    使得每一个卷积层的输出方差都为1,权值的初始化方法如下:
torch.nn.init.kaiming_normal_(tensor, a=0, mode=fan_in, nonlinearity=leaky_relu)
经验分享 程序员 微信小程序 职场和发展