"""Helpers for tiling image batches into a grid and saving frames as a GIF.

Example usage::

    grids = [create_grid(generated_batch[j]) for j in range(len(generated_batch))]
    save_as_gif_animation(grids, output_animation_path)
"""
import math

import numpy as np
from PIL import Image


def create_grid(image_arrays: np.ndarray) -> np.ndarray:
    """Tile a batch of images into a single square grid image.

    Args:
        image_arrays: Array-like of shape ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(n * height, n * width, channels)`` where
        ``n = ceil(sqrt(batch))``, with the same dtype as ``image_arrays``.
        Images are placed in row-major order; unused cells are zeros.
    """
    image_arrays = np.asarray(image_arrays)
    batch_size, image_height, image_width, n_channels = image_arrays.shape
    # Smallest square grid that holds every image.
    n_cols = n_rows = math.ceil(batch_size**0.5)

    # Preserve the input dtype: a float64 canvas would silently upcast uint8
    # images, and PIL cannot build a multi-channel image from float64 data.
    canvas = np.zeros(
        (image_height * n_rows, image_width * n_cols, n_channels),
        dtype=image_arrays.dtype,
    )
    for i, image in enumerate(image_arrays):
        row, col = divmod(i, n_cols)
        canvas[
            row * image_height : (row + 1) * image_height,
            col * image_width : (col + 1) * image_width,
            :,
        ] = image
    return canvas


def save_as_gif_animation(images, file_path):
    """Save a sequence of frames as an animated GIF.

    Args:
        images: Non-empty sequence (list or array) of ``(height, width,
            channels)`` frames, in a dtype ``PIL.Image.fromarray`` accepts
            (typically uint8). Single-channel frames are written as grayscale.
        file_path: Destination path of the GIF file.

    Raises:
        ValueError: If ``images`` contains no frames.
    """
    # Accept a list of arrays (as produced by create_grid) as well as an array.
    frames = [np.asarray(img) for img in images]
    if not frames:
        raise ValueError("images must contain at least one frame")
    if frames[0].shape[-1] == 1:
        # PIL wants grayscale as a 2-D array; drop only the channel axis so a
        # height or width of 1 is not collapsed too.
        frames = [frame[..., 0] for frame in frames]

    first, *rest = [Image.fromarray(frame) for frame in frames]
    first.save(
        file_path,
        save_all=True,
        duration=10,  # milliseconds per frame
        append_images=rest,
        # NOTE(review): in Pillow, loop=0 means "loop forever"; loop=1 does
        # not. Confirm this is intended.
        loop=1,
    )