Last active
January 25, 2025 15:41
-
-
Save NormXU/92fd2252652da312e46e4d02917f3ed8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib import animation | |
| from scipy.ndimage import gaussian_filter | |
| from PIL import Image | |
# Simulation / animation settings shared by the functions below.
n_steps = 1_000           # forward-process steps; each step uses dt = 1 / n_steps
bins = 100                # histogram bins over [-5, 5] recorded per step
num_dummy_images = 5_000  # samples drawn from EACH of the two Gaussian modes
n_animate = 50            # number of animation frames
num_pts = 5               # individual trajectories drawn as red lines
def add_noise(x, dt=0.001, mu=1.0, sigma=1.0):
    """Advance ``x`` by one Euler-Maruyama step of an Ornstein-Uhlenbeck process.

    The update is ``x + (-mu * x * dt + sigma * sqrt(dt) * z)``, where ``z`` is
    standard-normal noise of the same shape as ``x`` drawn from NumPy's global
    random state.

    Args:
        x: array of current values.
        dt: step size.
        mu: mean-reversion rate pulling values toward 0.
        sigma: scale of the injected noise.

    Returns:
        A new array holding the updated values.
    """
    z = np.random.randn(*x.shape)
    drift = -mu * x * dt
    diffusion = sigma * np.sqrt(dt) * z
    return x + (drift + diffusion)
def sample_two_gaussians(num_imgs, center=2.0, std=1.0):
    """Draw a 1-D toy dataset from a symmetric two-component Gaussian mixture.

    This mocks an "image distribution" with scalar samples. ``num_imgs``
    samples are drawn from N(+center, std**2) and another ``num_imgs`` from
    N(-center, std**2). The returned array therefore has ``2 * num_imgs``
    entries: the first half comes from the positive mode and the second half
    from the negative mode (no shuffling).

    Args:
        num_imgs: number of samples drawn from EACH component.
        center: absolute location of the two modes (default 2.0).
        std: standard deviation of each component (default 1.0).

    Returns:
        1-D float array of shape ``(2 * num_imgs,)``.
    """
    positive = np.random.randn(num_imgs) * std + center
    negative = np.random.randn(num_imgs) * std - center
    # Join the two 1-D sample sets end to end (positive mode first).
    return np.concatenate((positive, negative))
def normalize(image_array):
    """Map uint8 pixel values in [0, 255] to floats centered on 0.

    Each value is divided by 255 and then shifted down by 0.5, so
    0 -> -0.5 and 255 -> 0.5.
    """
    unit_range = image_array / 255.
    return unit_range - 0.5
def unnormalize(image_array):
    """Convert a normalized image back to uint8 pixels for display.

    This inverts ``normalize``: shift by +0.5, then scale by 255. Values are
    rounded to the nearest integer and clipped to [0, 255] before the cast.

    Clipping matters because images perturbed by ``add_noise`` routinely leave
    [-0.5, 0.5]. A raw ``astype(np.uint8)`` would wrap such values around
    (e.g. 256 -> 0), and converting negative floats to an unsigned type is
    platform-dependent. Either way the result is speckled garbage instead of
    saturated pixels.

    Rounding instead of truncating makes ``unnormalize(normalize(x)) == x``
    hold exactly, despite float error such as 0.99999... -> 0.
    """
    pixels = (image_array + 0.5) * 255
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
def _simulate_forward_process():
    """Run the forward noising process on the toy dataset and on a few probe points.

    Returns:
        smooth_path: array of shape (bins, n_steps). Column i holds the
            histogram counts of the dataset over [-5, 5] at step i (before
            that step's noise is applied), blurred with a Gaussian filter
            (sigma=3). Row 0 is the bin nearest -5.
        single_path: array of shape (n_steps, num_pts). Row i holds the probe
            points at step i, recorded at the same moment as the histogram,
            so row 0 is their uniform(-3, 3) starting position.
    """
    path = np.zeros((bins, n_steps))
    single_path = np.zeros((n_steps, num_pts))
    dt = 1 / n_steps
    # Toy "image" dataset: samples from two Gaussian modes at +/-2.
    xt = sample_two_gaussians(num_imgs=num_dummy_images)
    # A few individual starting points whose tracks are drawn as lines.
    single_xt = np.random.uniform(-3, 3, num_pts)
    for i in range(n_steps):
        path[:, i] = np.histogram(xt, bins=bins, range=(-5, 5))[0]
        single_path[i] = single_xt
        xt = add_noise(xt, dt=dt)
        single_xt = add_noise(single_xt, dt=dt)
    return gaussian_filter(path, sigma=3), single_path


def _load_image(image_path, width=256):
    """Load ``image_path`` as an RGB array ``width`` pixels wide, normalized.

    The height is scaled to keep the aspect ratio. Any transparency is
    flattened onto a white background first. Otherwise palette ('P') or
    grayscale PNGs would become index/luminance arrays, and an alpha channel
    would be noised along with the colours.
    """
    with Image.open(image_path) as img:
        rgba = img.convert('RGBA')
    background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
    rgb = Image.alpha_composite(background, rgba).convert('RGB')
    # Keep the resolution small: the image is re-noised and redrawn every frame.
    height = max(1, rgb.height * width // rgb.width)
    rgb = rgb.resize((width, height))
    return normalize(np.array(rgb))


def run_simulation(image_path='./pikachu.png', output_path='./forward_traj.gif'):
    """Render a GIF of the forward (noising) diffusion process.

    Left panel: the image at ``image_path``, progressively corrupted by
    ``add_noise``. Right panel: a heatmap of how the toy two-mode dataset's
    density evolves over ``n_steps`` steps, with ``num_pts`` individual
    trajectories drawn in red. Both panels advance on the same time scale
    (dt = 1 / n_steps per simulation step).

    Args:
        image_path: image to noise (default './pikachu.png').
        output_path: destination GIF (default './forward_traj.gif').
    """
    smooth_path, single_path = _simulate_forward_process()
    x_0 = _load_image(image_path)

    fig = plt.Figure(figsize=(15, 5))

    # The density heatmap and the trajectories share one axes.
    # - origin='lower' puts row 0 (the bin at -5) at the bottom.
    # - extent maps the heatmap into the trajectories' own coordinates
    #   (x = step index, y = value), so each red line sits on the density
    #   it was sampled from.
    ax = fig.add_subplot(3, 9, (4, 27))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(smooth_path, interpolation='nearest', aspect='auto',
              origin='lower', extent=(0, n_steps, -5, 5))
    lines = [ax.plot([], [], 'r-')[0] for _ in range(num_pts)]
    ax.set_xlim(0, n_steps)
    ax.set_ylim(-5, 5)

    ax_img = fig.add_subplot(3, 9, (1, 21))
    ax_img.get_xaxis().set_visible(False)
    ax_img.get_yaxis().set_visible(False)
    ax_img.margins(0, 0)
    # Keep the AxesImage and swap its pixels each frame. Clearing the axes
    # with cla() would reset the hidden x/y axes and make ticks reappear.
    img_artist = ax_img.imshow(unnormalize(x_0))

    frame_stride = n_steps // n_animate  # simulation steps per animation frame

    def update(step):
        nonlocal x_0
        # Both panels show simulated time frame_stride * step / n_steps.
        shown = frame_stride * step
        for idx, line in enumerate(lines):
            segment = single_path[:shown + 1, idx]
            line.set_data(np.arange(segment.shape[0]), segment)
        img_artist.set_data(unnormalize(x_0))
        # A single Euler step covering the same simulated time as
        # frame_stride trajectory steps keeps the image in sync with the heatmap.
        x_0 = add_noise(x_0, dt=frame_stride / n_steps)
        return [img_artist, *lines]

    ani = animation.FuncAnimation(fig, update, frames=range(n_animate))
    # writer='pillow' uses Pillow (already imported by this script) instead of
    # requiring the external ImageMagick program.
    ani.save(output_path, writer='pillow', fps=25)
# Script entry point: run the simulation and write the animation when this
# file is executed directly (not when imported).
if __name__ == '__main__':
    run_simulation()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Adapted from https://github.com/min-hieu/Tutorial_4