【TensorFlow 2】Ways to instantiate a Model in tf.keras
I need to use TensorFlow again soon, so here is a quick review.
Two ways to instantiate a Model
1. Functional API
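```python
import tensorflow as tf


def MyModel(input_shape):
    # Symbolic input; the batch dimension is left as None
    input1 = tf.keras.Input(shape=input_shape, name="input1")
    # A single fully connected layer with 4 units and ReLU activation
    X = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="dense1")(input1)
    # Wire inputs to outputs to build the model
    model = tf.keras.Model(inputs=input1, outputs=X, name="my_model")
    return model
```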
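To check the functional model, you can call it on a dummy batch. The sketch below repeats the `MyModel` definition from above and feeds it a hypothetical batch of two all-ones samples. It assumes TensorFlow 2.x with eager execution, which is the default. The batch size and input values are arbitrary and only for illustration.

```python
import numpy as np
import tensorflow as tf


def MyModel(input_shape):
    # Same functional definition as above: Input -> Dense(4, ReLU)
    input1 = tf.keras.Input(shape=input_shape, name="input1")
    X = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="dense1")(input1)
    return tf.keras.Model(inputs=input1, outputs=X, name="my_model")


model = MyModel((100,))
# Hypothetical dummy batch: 2 samples with 100 features each
batch = np.ones((2, 100), dtype="float32")
out = model(batch).numpy()
print(out.shape)  # (2, 4): the concrete batch size replaces None
```

All output values are non-negative because of the ReLU activation.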
2. Subclassing tf.keras.Model
```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self, input_shape):
        super(MyModel, self).__init__()  # must come first
        self.input1 = tf.keras.Input(shape=input_shape, name="input1")
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="dense1")
        self.out1 = self.call(self.input1)
        # Reinitialize with explicit inputs and outputs
        super(MyModel, self).__init__(
            inputs=self.input1,
            outputs=self.out1,
            name="my_model"
        )

    # Forward pass
    def call(self, inputs):
        """
        Args:
            inputs: the input tensor; its shape must match self.input_shape
        """
        x = self.dense1(inputs)
        return x
```
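For comparison, a more conventional subclassing style works differently from the approach above. It creates layers in `__init__`, defines the forward pass in `call`, and lets Keras create the weights on the first call, with no `tf.keras.Input` involved. This is a sketch that assumes TensorFlow 2.x. The class name `PlainModel` and the dummy batch are hypothetical.

```python
import tensorflow as tf


class PlainModel(tf.keras.Model):
    """Hypothetical conventional subclass: no tf.keras.Input inside __init__."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # initialize the base Model first
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="dense1")

    def call(self, inputs):
        # Forward pass: a single Dense layer
        return self.dense1(inputs)


model = PlainModel(name="plain_model")
# Weights are created lazily, on the first call with concrete data
y = model(tf.zeros((2, 100)))
print(y.shape)               # (2, 4)
print(model.count_params())  # 404
```

With this style, the weight shapes are only known after the first call. That is presumably why the version above creates a `tf.keras.Input` and reinitializes with `inputs`/`outputs`.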
summary() output
Run the following code:
```python
if __name__ == "__main__":
    model = MyModel((100,))
    model.summary()
```
The output is as follows:
- Functional API
```
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
```
- Subclassing Model
```
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
```
As you can see, both approaches produce the same summary() output.
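The 404 in the Param # column comes from the Dense layer's weights. The layer has a kernel of shape (100, 4) and a bias of shape (4,), giving 100 × 4 + 4 = 404. The sketch below rebuilds the functional model and checks this arithmetic against the layer's actual weights. It assumes TensorFlow 2.x.

```python
import tensorflow as tf

# Same functional model as above
input1 = tf.keras.Input(shape=(100,), name="input1")
out = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="dense1")(input1)
model = tf.keras.Model(inputs=input1, outputs=out, name="my_model")

dense = model.get_layer("dense1")
kernel, bias = dense.get_weights()  # NumPy arrays: kernel first, then bias
print(kernel.shape, bias.shape)     # (100, 4) (4,)

expected = 100 * 4 + 4  # inputs * units + one bias per unit
print(expected, model.count_params())  # 404 404
```

The InputLayer has no weights, which is why its row shows 0 parameters.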
model.save and load_model
- Functional API
```python
if __name__ == "__main__":
    model = MyModel((100,))
    model.save("mymodel.h5")  # save the model
    model = tf.keras.models.load_model("mymodel.h5")  # load the model
    model.summary()
```
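To confirm that the reloaded functional model is equivalent to the original, you can compare their predictions. This sketch makes several assumptions:

- TensorFlow 2.x with h5py installed, which `.h5` files need.
- `activation="relu"` passed as a string instead of `tf.nn.relu`, so the activation serializes as a plain Keras name.
- A temporary directory instead of the working directory.
- An arbitrary random input batch.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Same functional model as above, with the activation given by name
input1 = tf.keras.Input(shape=(100,), name="input1")
X = tf.keras.layers.Dense(4, activation="relu", name="dense1")(input1)
model = tf.keras.Model(inputs=input1, outputs=X, name="my_model")

# Arbitrary example batch: 3 samples with 100 features each
batch = np.random.default_rng(0).normal(size=(3, 100)).astype("float32")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mymodel.h5")
    model.save(path)  # HDF5 format, selected by the .h5 extension
    restored = tf.keras.models.load_model(path)

# The restored model should reproduce the original outputs
same = np.allclose(model.predict(batch, verbose=0), restored.predict(batch, verbose=0))
print(same)  # True
```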
- Subclassing Model
When loading the model, you must pass custom_objects explicitly:
```python
if __name__ == "__main__":
    model = MyModel((100,))
    model.save("mymodel.h5")  # save the model
    # Load the model; custom_objects must be given explicitly
    model = tf.keras.models.load_model("mymodel.h5", custom_objects={"MyModel": MyModel})
    model.summary()
```