Table of Contents
自定义Model
函数式API(Fucntional API)
本文代码基于 tensorflow2.0 python 3.7
tf.keras.Sequential
模型是层的简单堆叠,无法表示任意模型。
import tensorflow as tf from tensorflow import keras inputs = tf.keras.Input(shape=(128,)) # 构建一个输入张量 x = layers.Dense(256, activation='relu')(inputs) x = layers.Dense(512, activation='relu')(x) x = layers.Dense(128, activation='relu')(x) x = layers.Dense(64, activation='relu')(x) predictions = tf.nn.layers.Dense(10, activation='softmax')(x) model = tf.keras.Model(inputs=inputs, outputs=predictions) # 编译模型 model.compile(optimizer=tf.train.RMSPropOptimizer(0.001), loss='categorical_crossentropy', metrics=['accuracy']) model.fit(data, labels, batch_size=32, epochs=5)
模型子类化—实现自定义模型
"模型子类化"就是自己实现一个类来继承Model类,构建一个Model类的子类,
需要实现两个方法,即:
__init__() call()
通过对 tf.keras.Model
进行子类化并定义自己的前向传播来构建完全可自定义的模型。
在 init 方法中创建层并将它们设置为类实例的属性
在 call 方法中定义前向传播
下面给出典型的ResNet网络代码:
import os import tensorflow as tf import numpy as np from tensorflow import keras def conv3x3(channels, stride=1, kernel=(3, 3)): return keras.layers.Conv2D(channels, kernel, strides=stride, padding='same', use_bias=False, kernel_initializer=tf.random_normal_initializer()) class ResnetBlock(keras.Model): def __init__(self, channels, strides=1, residual_path=False): super().__init__() self.channels = channels self.strides = strides self.residual_path = residual_path self.conv1 = conv3x3(channels, strides) self.bn1 = keras.layers.BatchNormalization() self.conv2 = conv3x3(channels) self.bn2 = keras.layers.BatchNormalization() if residual_path: self.down_conv = conv3x3(channels, strides, kernel=(1, 1)) self.down_bn = tf.keras.layers.BatchNormalization() def call(self, inputs, training=None): residual = inputs x = self.bn1(inputs, training=training) x = tf.nn.relu(x) x = self.conv1(x) x = self.bn2(x, training=training) x = tf.nn.relu(x) x = self.conv2(x) if self.residual_path: residual = self.down_bn(inputs, training=training) residual = tf.nn.relu(residual) residual = self.down_conv(residual) x = x + residual return x class ResNet(keras.Model): def __init__(self, block_list, num_classes, initial_filters=16, **kwargs): super().__init__(**kwargs) self.num_blocks = len(block_list) self.block_list = block_list self.in_channels = initial_filters self.out_channels = initial_filters self.conv_initial = conv3x3(self.out_channels) self.blocks = keras.models.Sequential(name='dynamic-blocks') for block_id in range(len(block_list)): for layer_id in range(block_list[block_id]): if block_id != 0 and layer_id == 0: block = ResnetBlock(self.out_channels, strides=2, residual_path=True) else: if self.in_channels != self.out_channels: residual_path = True else: residual_path = False block = ResnetBlock(self.out_channels, residual_path=residual_path) self.in_channels = self.out_channels self.blocks.add(block) self.out_channels *= 2 self.final_bn = keras.layers.BatchNormalization() self.avg_pool = keras.layers.GlobalAveragePooling2D() self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None): out = self.conv_initial(inputs) out = self.blocks(out, training=training) out = self.final_bn(out, training=training) out = tf.nn.relu(out) out = self.avg_pool(out) out = self.fc(out) return out if __name__ == "__main__": model = ResNet([2, 2, 2], 10) model.build(input_shape=(None, 28, 28, 1)) model.summary()
总结:一般情况下,简单的应用可以直接使用函数式API编程,对于复杂的网络的定义和训练可以使用类继承的方式,这样的代码逻辑和封装性较好