"""Animate GPT-2's per-layer hidden states as a 3D trail while it samples text.

For every sampling step the hidden vector of the newest token is captured at
each transformer layer, all captured vectors are projected to 3D with PCA, and
the resulting path is rendered as a growing, fading trail with the most
recently generated tokens shown as a caption.  The result is written to
``gpt2_synced_trail.mp4`` (requires ffmpeg).
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.animation import FuncAnimation
# Imported for its side effect: older Matplotlib releases only register the
# "3d" projection once this module has been imported.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.decomposition import PCA

# === CONFIG ===
prompt = "1 2 3 4 5 "
num_generate = 200  # adjust if memory constrained
alpha_min = 0.05  # opacity of the oldest trail segment
display_tokens = 10  # how many trailing tokens the caption shows

# Prefer Apple MPS (the original default), then CUDA, then fall back to CPU so
# the script still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# === Load GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = (
    GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
    .to(device)
    .eval()
)

# === Generate + sync hidden states + text
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
generated_ids = input_ids.clone()
hidden_vectors = []  # one vector per (step, layer), in generation order
frame_texts = []  # caption for each step: the last `display_tokens` tokens
layers_per_token = 0  # set from the model output; not hard-coded to 12

with torch.no_grad():
    for _ in range(num_generate):
        outputs = model(generated_ids, output_hidden_states=True)

        # Index 0 of hidden_states is skipped (the embedding output per the
        # Transformers docs); keep one vector per transformer layer for the
        # newest token.
        layer_states = outputs.hidden_states[1:]
        layers_per_token = len(layer_states)
        hidden_vectors.extend(
            layer_hidden[0, -1, :].cpu().numpy() for layer_hidden in layer_states
        )

        # Decode the caption straight from the token ids.  Re-encoding decoded
        # text later is not guaranteed to reproduce the sampled tokens, and
        # doing it here avoids tokenizing again on every rendered frame.
        tail_ids = generated_ids[0, -display_tokens:].tolist()
        frame_texts.append(tokenizer.decode(tail_ids))

        # Sample the next token from the full softmax distribution.
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)

# === Project to 3D
hidden_matrix = np.stack(hidden_vectors)
points_3d = PCA(n_components=3).fit_transform(hidden_matrix)
total_points = len(points_3d)

# === Build animation frames
# Frame i shows every point collected up to and including step i.  `segment`
# is a slice (view) of points_3d, so no coordinates are copied per frame.
frames = []
cmap = colormaps["plasma"]
for step, caption in enumerate(frame_texts):
    end = (step + 1) * layers_per_token
    segment = points_3d[:end]
    # Opacity ramps linearly from alpha_min (oldest segment) towards 1 (newest).
    alphas = alpha_min + (1 - alpha_min) * (np.arange(end - 1) / end)
    frames.append((segment, alphas, caption))

# === Plot & Animate
fig = plt.figure(figsize=(30, 20))
fig.subplots_adjust(top=0.80)  # Leave space for text
ax = fig.add_subplot(111, projection="3d")
text_handle = ax.text2D(
    0.5,
    0.60,
    "",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=14,
    color="white",
    wrap=True,
)


def update(frame_idx):
    """Redraw the trail and caption for animation frame ``frame_idx``.

    Clears the axes, draws each consecutive pair of projected points as its
    own line (so every piece can carry its own colour and opacity), and
    refreshes the caption.  Returns an empty list because the animation is
    run with ``blit=False``.
    """
    ax.cla()
    ax.set_facecolor("#000000")
    ax.axis("off")

    segment, alphas, caption = frames[frame_idx]
    for i in range(len(segment) - 1):
        pair = segment[i:i + 2]
        ax.plot(
            pair[:, 0],
            pair[:, 1],
            pair[:, 2],
            # Colour follows the position along the complete trajectory, so a
            # given segment keeps the same colour in every frame.
            color=cmap(i / total_points),
            linewidth=2,
            alpha=alphas[i],
        )

    text_handle.set_text(caption)
    # cla() detached every artist from the axes, including the caption.
    ax.add_artist(text_handle)
    return []


ani = FuncAnimation(
    fig, update, frames=len(frames), blit=False, interval=100, repeat=True
)

# === Save or show
ani.save("gpt2_synced_trail.mp4", writer="ffmpeg", dpi=200)
# plt.show()  # Uncomment to preview live instead of saving