Keras Model

自定义Model

函数式API(Fucntional API)

本文代码基于 tensorflow2.0 python 3.7
tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。

import tensorflow as tf
from tensorflow import keras

inputs = tf.keras.Input(shape=(128,))  # 构建一个输入张量
x = layers.Dense(256, activation='relu')(inputs)
x = layers.Dense(512, activation='relu')(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(64, activation='relu')(x)
predictions = tf.nn.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 编译模型
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(data, labels, batch_size=32, epochs=5)

模型子类化—实现自定义模型

"模型子类化"就是自己实现一个类来继承Model类,构建一个Model类的子类,

需要实现两个方法,即:

__init__()
call(

通过对 tf.keras.Model 进行子类化并定义自己的前向传播来构建完全可自定义的模型。

init 方法中创建层并将它们设置为类实例的属性
在 call 方法中定义前向传播
下面给出典型的ResNet网络代码:

import os
import tensorflow as tf
import numpy as np
from tensorflow import keras

def conv3x3(channels, stride=1, kernel=(3, 3)):
    return keras.layers.Conv2D(channels, kernel, strides=stride, padding='same',
                               use_bias=False,
                               kernel_initializer=tf.random_normal_initializer())

class ResnetBlock(keras.Model):

    def __init__(self, channels, strides=1, residual_path=False):
        super().__init__()

        self.channels = channels
        self.strides = strides
        self.residual_path = residual_path

        self.conv1 = conv3x3(channels, strides)
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = conv3x3(channels)
        self.bn2 = keras.layers.BatchNormalization()

        if residual_path:
            self.down_conv = conv3x3(channels, strides, kernel=(1, 1))
            self.down_bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        residual = inputs

        x = self.bn1(inputs, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)

        if self.residual_path:
            residual = self.down_bn(inputs, training=training)
            residual = tf.nn.relu(residual)
            residual = self.down_conv(residual)

        x = x + residual
        return x

class ResNet(keras.Model):

    def __init__(self, block_list, num_classes, initial_filters=16, **kwargs):
        super().__init__(**kwargs)

        self.num_blocks = len(block_list)
        self.block_list = block_list

        self.in_channels = initial_filters
        self.out_channels = initial_filters
        self.conv_initial = conv3x3(self.out_channels)

        self.blocks = keras.models.Sequential(name='dynamic-blocks')

        for block_id in range(len(block_list)):
            for layer_id in range(block_list[block_id]):

                if block_id != 0 and layer_id == 0:
                    block = ResnetBlock(self.out_channels,
                                        strides=2, residual_path=True)
                else:
                    if self.in_channels != self.out_channels:
                        residual_path = True
                    else:
                        residual_path = False
                    block = ResnetBlock(self.out_channels,
                                        residual_path=residual_path)

                self.in_channels = self.out_channels

                self.blocks.add(block)

            self.out_channels *= 2

        self.final_bn = keras.layers.BatchNormalization()
        self.avg_pool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        out = self.conv_initial(inputs)
        out = self.blocks(out, training=training)
        out = self.final_bn(out, training=training)
        out = tf.nn.relu(out)
        out = self.avg_pool(out)
        out = self.fc(out)
        return out

if __name__ == "__main__":
    model = ResNet([2, 2, 2], 10)
    model.build(input_shape=(None, 28, 28, 1))
    model.summary()

总结:一般情况下,简单的应用可以直接使用函数式API编程,对于复杂的网络的定义和训练可以使用类继承的方式,这样的代码逻辑和封装性较好


如果你觉得这篇文章对你有帮助,不妨请我喝杯咖啡,鼓励我创造更多!