深度学习5---经典神经网络结构


一、LeNet

1、LeNet网络结构

卷积核大小:55 步长是0 池化核大小:22 步长是2

输入层:33232 卷积运算--------6355 C1:62828 池化运算--------22,s=2 S2:61414 卷积运算--------16655 C3:161010 池化运算--------22,s=2 S4:1655 全连接 1)展平 400 2)400120矩阵 C5:120 全连接 12084 矩阵 F6:84 全连接 84*10矩阵 output:10

根据10个数字大小,来确定预测结果 eg:第0个数字是最大的,预测结果是数字0的概率最大

2、LeNet自实现

import torch.nn as nn
import torch
import pdb #调试
import numpy as np

#自定义的网络结构,都必须继承nn.Module
class Lenet(nn.Module):
    def __init__(self,num_classes=10):
        super(Lenet, self).__init__()
        #初始化的信息
        #定义各个卷积运算、池化运算、激活、全连接
        self.conv1=nn.Conv2d(in_channels=3,
                             out_channels=6,
                             kernel_size=(5,5),
                             stride=1,
                             padding=0
                             )
        self.pool=nn.MaxPool2d(kernel_size=(2,2),stride=2)
        self.conv2=nn.Conv2d(in_channels=6,
                             out_channels=16,
                             kernel_size=(5,5),
                             stride=1,
                             padding=0
        )
        self.fc1=nn.Linear(in_features=400,
                           out_features=120
                           )
        self.fc2 = nn.Linear(in_features=120,
                             out_features=84
                             )
        self.fc3 = nn.Linear(in_features=84,
                             out_features=num_classes
                             )
        self.relu=nn.ReLU(inplace=True)
        pass

    def forward(self,X):
        #前向传播的执行顺序
        #X是输入数据N*C*H*W(样本数*通道数*H*W)
        out=self.conv1(X)
        out=self.relu(out)
        out = self.pool(out)

        out=self.conv2(out)
        out = self.relu(out)
        out = self.pool(out)

        out=out.reshape(out.shape[0],-1)   #展平操作

        out=self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out
        pass

#实例化 网络结构
model=Lenet(num_classes=2)

input_data=torch.rand((1,3,32,32))

output=model(input_data)
y_pred=np.argmax(output.detach().numpy(),axis=1)
print("输出结果",output,"预测结果是类",y_pred)

二、其他网络结构

1、AlexNet

2、VGG

3、GoogLenet

经验分享 程序员 微信小程序 职场和发展