import tensorflow as tf


# Usage: model.add(ResizeImage((height, width)))
# e.g.:  model.add(ResizeImage(model.layers[0].output_shape[1:3]))
class ResizeImage(tf.keras.layers.Layer):
    """Keras layer that resizes a batch of images to a fixed spatial size.

    The layer has no trainable weights; it only wraps TensorFlow's image
    resize op so it can be placed inside a Keras model.

    Args:
        output_dim: Two-element sequence ``(height, width)`` giving the
            target spatial size.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``, ...).

    Raises:
        ValueError: If ``output_dim`` does not have exactly two entries.
    """

    def __init__(self, output_dim, **kwargs):
        if len(output_dim) != 2:
            raise ValueError(
                f"output_dim must be (height, width), got {output_dim!r}"
            )
        self.output_dim = output_dim
        super().__init__(**kwargs)

    def call(self, x, mask=None):
        """Resize ``x`` to ``self.output_dim``.

        NOTE(review): assumes ``x`` is channels-last, i.e.
        (batch, height, width, channels), which is the layout the TF
        resize ops expect; confirm for your model.
        """
        # TF1 exposes `tf.image.resize_images`; TF2 removed it in favour of
        # `tf.image.resize`. Prefer the old name when present so TF1
        # results stay unchanged, and fall back so the layer runs on TF2.
        resize = getattr(tf.image, "resize_images", None) or tf.image.resize
        return resize(x, self.output_dim)

    def compute_output_shape(self, input_shape):
        """Return the output shape: batch and channel axes kept, spatial replaced."""
        return (
            input_shape[0],
            self.output_dim[0],
            self.output_dim[1],
            input_shape[3],
        )

    # Keras 1 name for the same hook; kept so existing callers keep working.
    get_output_shape_for = compute_output_shape

    def get_config(self):
        """Return the layer config so models using it can be saved/reloaded."""
        config = super().get_config()
        config["output_dim"] = self.output_dim
        return config