from keras import backend as K
from keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU
from keras.layers import concatenate  # functional interface
from keras.models import Model

N_INPUT = 512


def get_unet():
    # Channel counts are powers of two: 2**4 = 16 up to 2**7 = 128.
    n_ch_exps = [4, 5, 6, 6, 7, 7]
    kernels = (5, 5)

    if K.image_data_format() == 'channels_first':
        ch_axis = 1
        input_shape = (1, N_INPUT, N_INPUT)
    else:  # 'channels_last'
        ch_axis = 3
        input_shape = (N_INPUT, N_INPUT, 1)

    inp = Input(shape=input_shape)
    encodeds = []

    # encoder: each stride-2 Conv2D halves the spatial size (512 -> 8)
    enc = inp
    for l_idx, n_ch in enumerate(n_ch_exps):
        enc = Conv2D(2 ** n_ch, kernels,
                     strides=(2, 2), padding='same',
                     kernel_initializer='he_normal')(enc)
        enc = LeakyReLU(name='encoded_{}'.format(l_idx), alpha=0.2)(enc)
        encodeds.append(enc)

    # decoder: each stride-2 Conv2DTranspose doubles the spatial size, and the
    # result is concatenated with the encoder output of the same resolution
    dec = enc
    decoder_n_chs = n_ch_exps[::-1][1:]
    for l_idx, n_ch in enumerate(decoder_n_chs):
        # index of the encoder stage whose spatial size matches this step
        l_idx_rev = len(n_ch_exps) - l_idx - 2
        dec = Conv2DTranspose(2 ** n_ch, kernels,
                              strides=(2, 2), padding='same',
                              kernel_initializer='he_normal',
                              activation='relu',
                              name='decoded_{}'.format(l_idx))(dec)
        dec = concatenate([dec, encodeds[l_idx_rev]], axis=ch_axis)

    # output: one sigmoid channel at the full input resolution
    outp = Conv2DTranspose(1, kernels,
                           strides=(2, 2), padding='same',
                           kernel_initializer='glorot_normal',
                           activation='sigmoid',
                           name='decoded_{}'.format(len(decoder_n_chs)))(dec)

    unet = Model(inputs=inp, outputs=outp)
    return unet


if __name__ == "__main__":
    model = get_unet()
    model.summary()
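

# Shape sanity check: a hypothetical helper, not part of the original model.
# Each encoder Conv2D uses strides=(2, 2) with padding='same', so every stage
# maps a spatial size n to ceil(n / 2); each decoder Conv2DTranspose with the
# same settings doubles it. Decoder step l_idx concatenates with
# encodeds[len(n_ch_exps) - l_idx - 2], so both spatial sizes must agree or
# concatenate() fails. This sketch assumes a square input and the default
# n_ch_exps used in get_unet(); it only does arithmetic and needs no Keras.
# Usage: unet_stage_sizes(512) lists (size, channels) per stage, while an
# input such as 100 (not divisible by 2**6) raises ValueError.


def unet_stage_sizes(n_input=512, n_ch_exps=(4, 5, 6, 6, 7, 7)):
    """Return (encoder, decoder) lists of (spatial_size, channels) per stage."""
    encoder = []
    size = n_input
    for n_ch in n_ch_exps:
        size = -(-size // 2)  # stride 2 with 'same' padding -> ceil(size / 2)
        encoder.append((size, 2 ** n_ch))

    decoder = []
    for l_idx, n_ch in enumerate(n_ch_exps[::-1][1:]):
        size *= 2  # stride-2 Conv2DTranspose with 'same' padding doubles size
        skip_size, skip_ch = encoder[len(n_ch_exps) - l_idx - 2]
        if skip_size != size:
            raise ValueError(
                "decoder stage {} has size {} but its skip has size {}".format(
                    l_idx, size, skip_size))
        # concatenate() stacks channels: upsampled features + skip features
        decoder.append((size, 2 ** n_ch + skip_ch))

    # final 1-channel Conv2DTranspose back to the input resolution
    decoder.append((size * 2, 1))
    return encoder, decoder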