"""Conway's Game of Life on a torus, computed with a Keras convolution.

The script simulates a glider on a ``size`` x ``size`` grid with wrap-around
borders, renders every generation with matplotlib and saves the result to
``glider.gif``.
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib.animation import FuncAnimation
from tensorflow.keras.layers import Conv2D, InputLayer, Layer
from tensorflow.keras.models import Sequential

size = 10
full_size = (1, size, size, 1)  # (batch, rows, cols, channels)
padded_size = (1, size + 2, size + 2, 1)  # one extra cell on every border
n_frames = 120
glider = ((1, 2), (2, 3), (3, 1), (3, 2), (3, 3))

# Alternative start: a random grid instead of a single glider.
# env = np.random.randint(0, 2, full_size)
env = np.zeros(full_size, dtype=int)
for pos in glider:
    env[(0,) + pos] = 1


class TorusPaddingLayer(Layer):
    """Pad a single ``size`` x ``size`` grid by one cell with wrap-around.

    The padding is expressed as two matrix products, ``pre @ grid @ pre.T``,
    where ``pre`` is a ``(size + 2, size)`` selection matrix: its first row
    picks the last row of the grid, the middle is the identity, and its last
    row picks the first row of the grid.

    Based on:
    https://stackoverflow.com/questions/39088489/tensorflow-periodic-padding
    """

    def __init__(self, **kwargs):
        """Build the selection matrix ``pre`` and its transpose.

        Args:
            **kwargs: Forwarded unchanged to ``Layer.__init__``.
        """
        super().__init__(**kwargs)
        top_row = np.zeros((1, size))
        bottom_row = np.zeros((1, size))
        top_row[0, -1] = 1  # padded top row copies the grid's last row
        bottom_row[-1, 0] = 1  # padded bottom row copies the grid's first row
        self.pre = tf.convert_to_tensor(
            np.vstack((top_row, np.eye(size), bottom_row)), dtype=tf.int32
        )
        self.pre_T = tf.transpose(self.pre)

    def call(self, inputs):
        """Return ``inputs`` padded toroidally, shaped ``(1, size+2, size+2, 1)``.

        Args:
            inputs: Grid whose singleton axes squeeze away to leave a
                ``size`` x ``size`` matrix (the script passes ``full_size``).

        Returns:
            An int32 tensor with the batch and channel axes restored.
        """
        # tf.matmul requires both operands to share a dtype. The grid comes
        # from NumPy with the platform default integer (commonly int64), while
        # ``pre`` is int32, so align the input explicitly before multiplying.
        squeezed = tf.cast(tf.squeeze(inputs), self.pre.dtype)
        # Left product wraps the rows, right product wraps the columns.
        result = self.pre @ squeezed @ self.pre_T
        result = tf.expand_dims(result, 0)
        result = tf.expand_dims(result, -1)
        return result


torus_padding = TorusPaddingLayer()
# A 3x3 all-ones kernel without bias sums every 3x3 neighbourhood; "valid"
# padding shrinks the padded grid back to size x size.
model = Sequential(
    [
        InputLayer(input_shape=padded_size[1:]),
        Conv2D(
            1,
            3,
            padding="valid",
            activation=None,
            use_bias=False,
            kernel_initializer="ones",
        ),
    ]
)

frames = []
for i in range(n_frames):
    # 2D sliding window of 3x3 including summation
    # NOTE: convolve2d of scipy does support torus-padding but that's obviously
    # not as cool as a neural network
    # TODO: Add custom layer to Sequential model, causing bugs at the moment
    padded = torus_padding(env)
    # Don't count the cell itself in the number of neighbours
    neighbours = model(padded) - env
    # Life rules: a live cell survives with 2 or 3 neighbours, a dead cell is
    # born with exactly 3; everything else is dead in the next generation.
    env = np.where(
        (env & np.isin(neighbours, (2, 3))) | ((env == 0) & (neighbours == 3)),
        1,
        0,
    )
    frames.append(env.squeeze())

fig = plt.figure()
ax = plt.axes(xlim=(0, size), ylim=(0, size))
render = plt.imshow(frames[0], interpolation="none", cmap="binary")


def animate(i: int):
    """Show generation ``i`` and return the redrawn artists.

    Args:
        i: Index into ``frames`` supplied by ``FuncAnimation``.

    Returns:
        A list holding the image artist, as blitting needs the changed artists.
    """
    render.set_array(frames[i])
    return [render]


anim = FuncAnimation(fig, animate, frames=n_frames, interval=30, blit=True)
plt.axis("off")
plt.gca().invert_yaxis()
anim.save("glider.gif", fps=30)
plt.show()