快捷搜索: 王者荣耀 脱发

双分支CNN卷积神经网络搭建 TensorFlow

在掌握了的搭建方法后,我们来了解一下双分支CNN卷积网络的搭建基本方法。

下面我们来用代码构建一个如下结构的网络:

现在网络上面关于传统神经网络搭建的讲解有很多,但是对于多分支网络的讲解很少,所以很多人第一次看到这种网络会一头雾水并感到无从下手,其实这样的网络搭建并不算难。让大家感到困难的点在于如何将两个分支网络分别独立训练并在后面汇总到一起。

其实只需要像传统神经网络那样分别构建两个分支的网络(注意用不同的变量名区分,不要产生变量交集)然后再用concatenate或concat函数拼接即可

具体网络搭建以及调用示例代码如下:

输入数据预处理:

由于是多输入,因此需要对每一个输入数据定义标签

input_1 = tf.keras.Input(shape=input_1.shape, name="input_1")
input_2 = tf.keras.Input(shape=input_2.shape, name="input_2")#分别定义input_1和input_2的标签

模型搭建:

input_1_features=tf.keras.layers.Conv2D(filters=64,
                                      kernel_size=[4,4],
                                      padding=same,
                                      activation=tf.nn.relu)(input_1)#将输入数据input1输入到网络中
 
input_1_features= tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2)(input_1_features)
 
input_2_features=tf.keras.layers.Conv2D(filters=64,
                                      kernel_size=[4,4],
                                      padding=same,
                                      activation=tf.nn.relu)(input_2)#将输入数据input2输入到网络中
 
input_2_features= tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2)(input_2_features)

x=tf.keras.layers.concatenate([input_1_features,input_2_features])#将两个处理好的分支做concatenate处理,变为一个数据块
department_pred = tf.keras.layers.Dense(num,name="output", activation=softmax)(x)#做展平处理最终的分类结果。将输出标记为"output"

模型定义:

model = tf.keras.Model(
        inputs=[input_1,input_2],
        outputs=department_pred,
        )

模型编译:

optimizer=tf.keras.optimizers.Adam(learning_rate=0.01)#定义学习率
model.compile(optimizer=optimizer,#导入上面的学习率
              loss=categorical_crossentropy,#选择损失函数
              metrics=[accuracy]#在训练时输出对训练集的精确度(可删)
             )

模型训练并输出结果:

model.fit(
        {"input_1": input_1,"input_2": input_2},#导入训练集input_1,input_2
        {"output":Y_train},#导入训练集标签(这里的数据和标签分别对应之前模型定义时赋予各数据的标签)
        epochs=2000,#迭代次数设置为2000
        batch_size=32,#每次送入32个数据进行训练(如果少于32则会直接带入当前全部)
        shuffle=True#打乱数据集顺序,防止过拟合
        validation_data=([X_test_1,X_test_2], Y_test),#导入验证集(注意验证集同样需要与训练集有完全相同的格式和数量)
        )

最后要注意一点就是这种网络要保证输入数据input_1和input_2的形状完全一样,否则无法进行concatenate操作。

经验分享 程序员 微信小程序 职场和发展