"Open

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import (
 AutoModelForCausalLM,
 AutoTokenizer,
)


In [None]:

def print_probability_distribution(current, probabilities, tokenizer, top_n=20, bar_width=50):
    """
    Print the top N tokens and their probabilities in ASCII format.

    Parameters:
    - current: Current context as a string, printed as a header line.
    - probabilities: Probability distribution over the vocabulary as a 1-D
      array-like indexed by token ID (a NumPy array in this notebook).
    - tokenizer: Tokenizer whose ``decode([token_id])`` returns the token text.
    - top_n: Number of top tokens to display. Clamped to [0, vocabulary size],
      so 0 (or a negative value) prints no rows instead of the whole vocabulary.
    - bar_width: Number of '#' characters drawn for a probability of 1.0.

    Raises:
    - ValueError: if ``probabilities`` is empty.
    """
    probabilities = np.asarray(probabilities)
    if probabilities.size == 0:
        raise ValueError("probabilities must not be empty")

    def _label(token_id):
        # Strip surrounding whitespace so the column lines up, but show
        # whitespace-only tokens (e.g. a newline) via repr() instead of a blank label.
        raw = tokenizer.decode([int(token_id)])
        stripped = raw.strip()
        return stripped if stripped else repr(raw)

    # Without clamping, top_n=0 turns `argsort(...)[-0:]` into "the entire vocabulary".
    top_n = max(0, min(int(top_n), probabilities.size))
    top_indices = np.argsort(probabilities)[::-1][:top_n]  # highest probability first
    top_probs = probabilities[top_indices]
    top_tokens = [_label(i) for i in top_indices]

    # Greedy next token via np.argmax -- the same rule the generation loop uses --
    # so the header stays correct even when top_n == 0.
    max_token = _label(np.argmax(probabilities))

    # Display the current context
    print(f"Context: {current}")
    print(f"Next Token Prediction: '{max_token}'\n")

    # Print the top N tokens and their probabilities as an ASCII bar chart
    for token, prob in zip(top_tokens, top_probs):
        bar = "#" * int(prob * bar_width)
        print(f"{token:>15} | {bar} {prob:.4f}")

def plot_probability_distribution(current, probabilities, tokenizer, top_n=20):
    """
    Plot the top N tokens and their probabilities as a bar chart.

    The greedy next token (highest probability) is drawn in red, the others in
    blue, and the current context plus the predicted token appear in a text box.

    Parameters:
    - current: Current context as a string.
    - probabilities: Probability distribution over the vocabulary as a 1-D
      array-like indexed by token ID.
    - tokenizer: Tokenizer whose ``decode([token_id])`` returns the token text.
    - top_n: Number of top tokens to plot, clamped to [0, vocabulary size] so
      that 0 does not plot the entire vocabulary.

    Raises:
    - ValueError: if ``probabilities`` is empty.
    """
    probabilities = np.asarray(probabilities)
    if probabilities.size == 0:
        raise ValueError("probabilities must not be empty")

    def _label(token_id):
        # Keep the raw decoded text (leading spaces are part of tokens such as
        # " York"), but use repr() for whitespace-only tokens so the label isn't blank.
        raw = tokenizer.decode([int(token_id)])
        return raw if raw.strip() else repr(raw)

    # Get top N tokens and their probabilities (highest first)
    top_n = max(0, min(int(top_n), probabilities.size))
    top_indices = np.argsort(probabilities)[::-1][:top_n]
    top_probs = probabilities[top_indices]
    top_tokens = [_label(i) for i in top_indices]

    # Greedy next token, same rule (np.argmax) as the generation loop.
    max_token = _label(np.argmax(probabilities))

    fig, ax = plt.subplots(figsize=(12, 7))
    # Bars go at numeric positions with the tokens as tick labels: passing the
    # strings as x-values makes the axis categorical, which stacks tokens that
    # decode to identical text onto a single bar.
    positions = np.arange(len(top_tokens))
    bars = ax.bar(positions, top_probs, color="blue")
    if len(bars):
        bars[0].set_color("red")  # Highlight the next token
    ax.set_xticks(positions)
    ax.set_xticklabels(top_tokens, rotation=45)

    # Add the current context inside the graph
    ax.text(
        0.5,
        0.9,
        f"Context: {current}\nNext Token: {max_token}",
        ha="center",
        va="center",
        transform=ax.transAxes,
        fontsize=12,
        bbox=dict(facecolor="white", alpha=0.8, edgecolor="black"),
    )

    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilities")
    fig.tight_layout()
    plt.show()



In [None]:
model_name = 'gpt2'
#model_name = "meta-llama/Llama-3.2-1B-Instruct" # try with this also
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 halves GPU memory, but many CPU kernels are slow or unimplemented in
# half precision, so only use it when a CUDA device is available.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
)
model = model.to(device)
model.eval()  # disable dropout; the repr shown below is the model architecture


GPT2LMHeadModel(
 (transformer): GPT2Model(
 (wte): Embedding(50257, 768)
 (wpe): Embedding(1024, 768)
 (drop): Dropout(p=0.1, inplace=False)
 (h): ModuleList(
 (0-11): 12 x GPT2Block(
 (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
 (attn): GPT2SdpaAttention(
 (c_attn): Conv1D(nf=2304, nx=768)
 (c_proj): Conv1D(nf=768, nx=768)
 (attn_dropout): Dropout(p=0.1, inplace=False)
 (resid_dropout): Dropout(p=0.1, inplace=False)
 )
 (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
 (mlp): GPT2MLP(
 (c_fc): Conv1D(nf=3072, nx=768)
 (c_proj): Conv1D(nf=768, nx=3072)
 (act): NewGELUActivation()
 (dropout): Dropout(p=0.1, inplace=False)
 )
 )
 )
 (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
 )
 (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
# Get the vocabulary as a dictionary {token: token_id}
vocab = tokenizer.get_vocab()
# Print the vocabulary size
print(f"Vocabulary Size: {len(vocab)}")
prompt_template = "I love New"

if model_name == "meta-llama/Llama-3.2-1B-Instruct":
    # Instruct models expect their chat format. Let the tokenizer's own chat
    # template build it rather than hand-writing special tokens, which also
    # carried the cell's indentation and line breaks into the prompt.
    system_message = "You complete sentences with funny words"
    question = "Complete the sentence I love New"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": question},
    ]
    prompt_template = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

print(f"Original Text: {prompt_template}")
# add_special_tokens=False: a chat-formatted prompt already contains its BOS
# token, so letting encode() prepend another one would duplicate it.
input_id_list = list(tokenizer.encode(prompt_template, add_special_tokens=False))
# Work on a copy so `input_id_list` keeps the original prompt IDs.
text = list(input_id_list)
generated_tokens = []

# Set the number of tokens to generate
N = 10

# Iterative greedy generation: re-run the model on the full sequence each step.
for i in range(N):
    current_input = torch.tensor([text], dtype=torch.long, device=device)

    # Forward pass to get logits
    with torch.no_grad():
        logits = model(current_input).logits

    # Softmax in float32: with a float16 model, small probabilities over the
    # large vocabulary can underflow to 0 and the sum drifts away from 1.
    # In float32 the result is already normalized.
    probabilities = torch.softmax(logits[0, -1].float(), dim=0).cpu().numpy()

    # Find the token with the maximum probability (plain int for list/tokenizer use)
    max_token_id = int(np.argmax(probabilities))
    max_token = tokenizer.decode([max_token_id])
    generated_tokens.append(max_token)

    # Append the generated token to the input for the next iteration
    text.append(max_token_id)

    # Decode current context for display
    current = tokenizer.decode(text)
    print(f"Decoded Context: {current}")
    print(f"Max Probability Token: '{max_token}' (ID: {max_token_id} word {i})")

    # Show the probability distribution (swap in the plot version for a chart)
    #plot_probability_distribution(current, probabilities, tokenizer, top_n=10)
    print_probability_distribution(current, probabilities, tokenizer, top_n=10)

    # Stop once the model emits its end-of-sequence token.
    if max_token_id == tokenizer.eos_token_id:
        break

# Final Output
final_generated_text = tokenizer.decode(text)
print(f"\nFinal Generated Text: {final_generated_text}")


Vocabulary Size: 50257
Original Text: I love New
Decoded Context: I love New York
Max Probability Token: ' York' (ID: 1971 word 0)
Context: I love New York
Next Token Prediction: 'York'

 York | ##################### 0.4355
 Orleans | #### 0.0972
 Zealand | #### 0.0885
 England | ## 0.0504
 Jersey | # 0.0393
 Year | # 0.0278
 Yorkers | 0.0191
 Mexico | 0.0144
 Hampshire | 0.0096
 Years | 0.0085
Decoded Context: I love New York.
Max Probability Token: '.' (ID: 13 word 1)
Context: I love New York.
Next Token Prediction: '.'

 . | ########## 0.2020
 , | ######## 0.1783
 and | #### 0.0955
 City | ### 0.0792
 ! | # 0.0398
 ," | # 0.0351
 ." | # 0.0291
 !" | # 0.0273
 so | # 0.0227
 's | # 0.0227
Decoded Context: I love New York. I
Max Probability Token: ' I' (ID: 314 word 2)
Context: I love New York. I
Next Token Prediction: 'I'

 I | ################ 0.3269
 It | ###### 0.1202
 | ## 0.0533
 We | ## 0.0471
 But | ## 0.0416
 And | # 0.0390
 The | # 0.0268
 You | 0.0184
 So | 0.0163
 My | 0.0