Skip to content

Instantly share code, notes, and snippets.

@NormXU
Last active January 25, 2025 15:41
Show Gist options
  • Save NormXU/92fd2252652da312e46e4d02917f3ed8 to your computer and use it in GitHub Desktop.
Save NormXU/92fd2252652da312e46e4d02917f3ed8 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy.ndimage import gaussian_filter
from PIL import Image
# --- Simulation / animation settings ---------------------------------------
n_steps = 1_000           # forward-process time steps; each uses dt = 1 / n_steps
bins = 100                # histogram bins covering x in [-5, 5]
num_dummy_images = 5_000  # samples drawn from *each* of the two Gaussian modes
n_animate = 50            # number of frames in the saved animation
num_pts = 5               # individual trajectories overlaid as red lines
def add_noise(x, dt=0.001, mu=1.0, sigma=1.0):
    """Advance ``x`` by one Euler-Maruyama step of an Ornstein-Uhlenbeck process.

    Computes ``x + (-mu * x * dt + sigma * sqrt(dt) * eps)`` where ``eps`` is
    standard-normal noise drawn element-wise from NumPy's global random state.

    Args:
        x: NumPy array of current values; the noise matches its shape.
        dt: step size.
        mu: mean-reversion rate (the drift pulls values toward 0).
        sigma: scale of the random increment.

    Returns:
        A new array of the same shape; ``x`` itself is not modified.
    """
    eps = np.random.randn(*x.shape)
    drift = -mu * x * dt
    diffusion = sigma * np.sqrt(dt) * eps
    return x + (drift + diffusion)
def sample_two_gaussians(num_imgs, mean=2.0, std=1.0):
    """Draw a symmetric two-mode Gaussian mixture used to mock an image distribution.

    ``num_imgs`` samples are drawn from N(+mean, std**2) and another
    ``num_imgs`` from N(-mean, std**2). The two groups are concatenated, so
    the result holds ``2 * num_imgs`` values, with the positive mode first.

    Args:
        num_imgs: number of samples *per mode* (not the total).
        mean: distance of each mode's centre from 0 (default 2.0).
        std: standard deviation of each mode (default 1.0).

    Returns:
        1-D float array of shape ``(2 * num_imgs,)``.
    """
    upper = std * np.random.randn(num_imgs) + mean
    lower = std * np.random.randn(num_imgs) - mean
    # Join the two 1-D groups end to end (positive mode first).
    return np.concatenate((upper, lower))
def normalize(image_array):
    """Map 8-bit pixel values onto floats centred on zero.

    0 becomes -0.5 and 255 becomes +0.5. A new float array is returned and
    the input is left untouched.
    """
    unit_range = image_array / 255.
    return unit_range - 0.5
def unnormalize(image_array):
    """Convert a zero-centred float image back to ``uint8`` for display.

    This is the inverse of ``normalize``: -0.5 maps to 0 and +0.5 maps to 255.
    Values are rounded to the nearest integer and clipped to [0, 255] before
    the cast.

    Args:
        image_array: float array, nominally in [-0.5, 0.5]. After noising it
            may fall outside that range.

    Returns:
        ``uint8`` array of the same shape.
    """
    scaled = (image_array + 0.5) * 255
    # Casting out-of-range floats to uint8 is undefined in NumPy (in practice
    # it wraps around, turning bright pixels dark and vice versa), so clip
    # first. Round rather than truncate so e.g. 254.9999 stays 255.
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
def _simulate_trajectories():
    """Run the forward noising process on the toy 1-D dataset.

    Returns:
        smooth_path: ``(bins, n_steps)`` array. Column ``i`` is the histogram
            of all samples over [-5, 5] at step ``i``, Gaussian-smoothed
            (sigma=3) for display.
        single_path: ``(n_steps, num_pts)`` array. Column ``j`` is the
            trajectory of the ``j``-th individually tracked point.
    """
    dt = 1 / n_steps
    path = np.zeros((bins, n_steps))
    single_path = np.zeros((n_steps, num_pts))
    xt = sample_two_gaussians(num_imgs=num_dummy_images)
    # A few independent points, started uniformly in [-3, 3], whose
    # trajectories are drawn as lines over the heatmap.
    single_xt = np.random.uniform(-3, 3, num_pts)
    for i in range(n_steps):
        # The histogram is recorded before step i's update; the tracked points
        # are recorded after it.
        path[:, i] = np.histogram(xt, bins=bins, range=(-5, 5))[0]
        xt = add_noise(xt, dt=dt)
        single_xt = add_noise(single_xt, dt=dt)
        single_path[i] = single_xt
    return gaussian_filter(path, sigma=3), single_path


def _load_image(image_path, width=256):
    """Open ``image_path``, resize it to ``width`` px wide (keeping aspect), and normalize it."""
    # Keep the resolution small: every animation frame noises and re-renders
    # the whole array.
    with Image.open(image_path) as img:
        resized = img.resize((width, img.height * width // img.width))
    # NOTE(review): an RGBA image keeps its alpha channel, which is noised
    # along with the colour channels; confirm that is intended.
    return normalize(np.array(resized))


def run_simulation(image_path='./pikachu.png', output_path='./forward_traj.gif',
                   writer='imagemagick', fps=25):
    """Animate the forward (noising) diffusion process and save it.

    Left panel: the image at ``image_path``, receiving one ``add_noise`` step
    per frame. Right panel: a heatmap of the toy two-Gaussian distribution
    over ``n_steps`` steps, with ``num_pts`` individual trajectories drawn
    progressively on top in red.

    Args:
        image_path: image to noise (resized to 256 px wide).
        output_path: destination of the saved animation.
        writer: matplotlib animation writer name. ``'imagemagick'`` needs the
            ImageMagick binary; ``'pillow'`` only needs Pillow.
        fps: frames per second of the saved animation.
    """
    smooth_path, single_path = _simulate_trajectories()
    clean_image = _load_image(image_path)

    fig = plt.Figure(figsize=(15, 5))

    # Right two-thirds of a 3x9 grid: the distribution heatmap.
    # origin='lower' puts histogram bin 0 (x = -5) at the bottom, matching
    # the overlay's y-axis, which runs from -5 (bottom) to 5 (top).
    ax = fig.add_subplot(3, 9, (4, 27))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(smooth_path, interpolation='nearest', aspect='auto', origin='lower')

    # Transparent overlay in the same grid slot for the trajectories. The
    # distinct label stops matplotlib < 3.6 from returning ``ax`` itself for
    # identical add_subplot arguments.
    ax2 = fig.add_subplot(3, 9, (4, 27), label='trajectories')
    ax2.get_xaxis().set_visible(False)
    ax2.get_yaxis().set_visible(False)
    ax2.set_ylim(-5, 5)
    ax2.set_xlim(0, n_steps)
    ax2.patch.set_alpha(0.)
    # Empty placeholder lines; update() fills them in progressively.
    lines = [ax2.plot([], [], 'r-')[0] for _ in range(num_pts)]

    # Left third: the image being noised.
    ax_img = fig.add_subplot(3, 9, (1, 21))
    ax_img.get_xaxis().set_visible(False)
    ax_img.get_yaxis().set_visible(False)
    ax_img.margins(0, 0)
    # Create the AxesImage once and only swap its data each frame. Clearing
    # the axes would also undo the hidden-axis settings above.
    img_artist = ax_img.imshow(unnormalize(clean_image))

    x_t = clean_image

    def init():
        # Reset state so every (re)start of the animation begins from the
        # clean image. Without an init_func, FuncAnimation draws frame 0
        # through update(), which would noise the image ahead of time.
        nonlocal x_t
        x_t = clean_image
        for line in lines:
            line.set_data([], [])
        img_artist.set_data(unnormalize(x_t))
        return [*lines, img_artist]

    def update(step):
        nonlocal x_t
        # Reveal trajectories through the end of this frame's window, so the
        # last frame shows all n_steps points.
        end = (step + 1) * n_steps // n_animate
        for idx, line in enumerate(lines):
            line_y = single_path[:end, idx]
            line.set_data(range(line_y.shape[0]), line_y)
        img_artist.set_data(unnormalize(x_t))
        # NOTE(review): the image gets one add_noise step (default dt) per
        # frame, while the trajectories advance n_steps // n_animate steps of
        # dt = 1 / n_steps per frame; confirm this pacing is intended.
        x_t = add_noise(x_t)
        return [*lines, img_artist]

    ani = animation.FuncAnimation(fig, update, frames=range(n_animate), init_func=init)
    ani.save(output_path, writer=writer, fps=fps)
if __name__ == "__main__":
    # Script entry point: simulate, render and save the animation.
    run_simulation()
@NormXU
Copy link
Author

NormXU commented Jan 25, 2025

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment