@alexcpn
Created November 28, 2024 13:05

Revisions

  1. alexcpn revised this gist Nov 28, 2024. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion llm_probability2.ipynb
    @@ -4,7 +4,8 @@
     "metadata": {
     "colab": {
     "provenance": [],
    -"authorship_tag": "ABX9TyNr/UpNAHfOy+HC0ztAfuHW"
    +"authorship_tag": "ABX9TyNr/UpNAHfOy+HC0ztAfuHW",
    +"include_colab_link": true
     },
     "kernelspec": {
     "name": "python3",
    @@ -15,6 +16,16 @@
     }
     },
     "cells": [
    +{
    +"cell_type": "markdown",
    +"metadata": {
    +"id": "view-in-github",
    +"colab_type": "text"
    +},
    +"source": [
    +"<a href=\"https://colab.research.google.com/gist/alexcpn/36e90cb3c78695e6b09eb97cdb277414/llm_probability2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    +]
    +},
     {
     "cell_type": "code",
     "execution_count": null,
  2. alexcpn created this gist Nov 28, 2024.
    400 changes: 400 additions & 0 deletions llm_probability2.ipynb
    @@ -0,0 +1,400 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "colab": {
    "provenance": [],
    "authorship_tag": "ABX9TyNr/UpNAHfOy+HC0ztAfuHW"
    },
    "kernelspec": {
    "name": "python3",
    "display_name": "Python 3"
    },
    "language_info": {
    "name": "python"
    }
    },
    "cells": [
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "kDxTsn5QhwKK"
    },
    "outputs": [],
    "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import (\n",
    " AutoModelForCausalLM,\n",
    " AutoTokenizer,\n",
    ")\n"
    ]
    },
    {
    "cell_type": "code",
    "source": [
    "\n",
    "def print_probability_distribution(current, probabilities, tokenizer, top_n=20, bar_width=50):\n",
    " \"\"\"\n",
    " Print the top N tokens and their probabilities in ASCII format.\n",
    "\n",
    " Parameters:\n",
    " - current: Current context as a string.\n",
    " - probabilities: Probability distribution over the vocabulary.\n",
    " - tokenizer: Tokenizer to decode token IDs to tokens.\n",
    " - top_n: Number of top tokens to display.\n",
    " - bar_width: Width of the ASCII bar representing probabilities.\n",
    " \"\"\"\n",
    " # Get top N tokens and their probabilities\n",
    " top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
    " top_probs = probabilities[top_indices]\n",
    " top_tokens = [tokenizer.decode([i]).strip() for i in top_indices]\n",
    "\n",
    " # Find the next token (highest probability token)\n",
    " max_token = top_tokens[0]\n",
    "\n",
    " # Display the current context\n",
    " print(f\"Context: {current}\")\n",
    " print(f\"Next Token Prediction: '{max_token}'\\n\")\n",
    "\n",
    " # Print the top N tokens and their probabilities as an ASCII bar chart\n",
    " for token, prob in zip(top_tokens, top_probs):\n",
    " bar = \"#\" * int(prob * bar_width)\n",
    " print(f\"{token:>15} | {bar} {prob:.4f}\")\n",
    "\n",
    "def plot_probability_distribution(current, probabilities, tokenizer, top_n=20):\n",
    " # Get top N tokens and their probabilities\n",
    " top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
    " top_probs = probabilities[top_indices]\n",
    " top_tokens = [tokenizer.decode([i]) for i in top_indices]\n",
    "\n",
    " # Find the next token (highest probability token)\n",
    " max_token = tokenizer.decode([top_indices[0]])\n",
    "\n",
    " # Plot\n",
    " plt.figure(figsize=(12, 7))\n",
    " bars = plt.bar(top_tokens, top_probs, color=\"blue\")\n",
    " bars[0].set_color(\"red\") # Highlight the next token\n",
    "\n",
    " # Add the current context inside the graph\n",
    " plt.text(\n",
    " 0.5,\n",
    " 0.9,\n",
    " f\"Context: {current}\\nNext Token: {max_token}\",\n",
    " ha=\"center\",\n",
    " va=\"center\",\n",
    " transform=plt.gca().transAxes,\n",
    " fontsize=12,\n",
    " bbox=dict(facecolor=\"white\", alpha=0.8, edgecolor=\"black\"),\n",
    " )\n",
    "\n",
    " plt.xlabel(\"Tokens\")\n",
    " plt.ylabel(\"Probabilities\")\n",
    " plt.xticks(rotation=45)\n",
    " plt.tight_layout()\n",
    " plt.show()\n",
    "\n"
    ],
    "metadata": {
    "id": "mxq_nsjshx9N"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "model_name = 'gpt2'\n",
    "#model_name = \"meta-llama/Llama-3.2-1B-Instruct\" # try with this also\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    " model_name,\n",
    " torch_dtype=torch.float16,\n",
    ")\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "model.eval()\n"
    ],
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "gHJuOJHoiIo-",
    "outputId": "dac7e8ca-a337-44d6-84ed-18ec297b8b49"
    },
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "GPT2LMHeadModel(\n",
    " (transformer): GPT2Model(\n",
    " (wte): Embedding(50257, 768)\n",
    " (wpe): Embedding(1024, 768)\n",
    " (drop): Dropout(p=0.1, inplace=False)\n",
    " (h): ModuleList(\n",
    " (0-11): 12 x GPT2Block(\n",
    " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
    " (attn): GPT2SdpaAttention(\n",
    " (c_attn): Conv1D(nf=2304, nx=768)\n",
    " (c_proj): Conv1D(nf=768, nx=768)\n",
    " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
    " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
    " )\n",
    " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
    " (mlp): GPT2MLP(\n",
    " (c_fc): Conv1D(nf=3072, nx=768)\n",
    " (c_proj): Conv1D(nf=768, nx=3072)\n",
    " (act): NewGELUActivation()\n",
    " (dropout): Dropout(p=0.1, inplace=False)\n",
    " )\n",
    " )\n",
    " )\n",
    " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
    " )\n",
    " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
    ")"
    ]
    },
    "metadata": {},
    "execution_count": 8
    }
    ]
    },
    {
    "cell_type": "code",
    "source": [
    "# Get the vocabulary as a dictionary {token: token_id}\n",
    "vocab = tokenizer.get_vocab()\n",
    "# Print the vocabulary size\n",
    "print(f\"Vocabulary Size: {len(vocab)}\")\n",
    "prompt_template = \"I love New\"\n",
    "\n",
    "if model_name == \"meta-llama/Llama-3.2-1B-Instruct\" :\n",
    " # use its format as we are using the Instuct model, the prompt template is as below\n",
    " system_message =\"You complete sentences with funny words\"\n",
    " question = \"Complete the sentence I love New\"\n",
    " prompt_template=f'''\n",
    " <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
    " {system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
    " {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
    " '''\n",
    "\n",
    "print(f\"Original Text: {prompt_template}\")\n",
    "input_id_list = list(tokenizer.encode(prompt_template))\n",
    "text =input_id_list\n",
    "generated_tokens = []\n",
    "\n",
    "# Set the number of tokens to generate\n",
    "N = 10\n",
    "\n",
    "# Iterative generation\n",
    "for i in range(N):\n",
    " current_input = torch.tensor([text], dtype=torch.long)\n",
    "\n",
    " # Forward pass to get logits\n",
    " with torch.no_grad():\n",
    " outputs = model(current_input.to(device))\n",
    " logits = outputs.logits\n",
    "\n",
    " # Get probabilities for the last token\n",
    " probabilities = torch.softmax(logits[0, -1], dim=0).cpu().numpy()\n",
    " probabilities /= probabilities.sum() # Normalize\n",
    "\n",
    " # Find the token with the maximum probability\n",
    " max_token_id = np.argmax(probabilities)\n",
    " max_token = tokenizer.decode([max_token_id])\n",
    " generated_tokens.append(max_token)\n",
    "\n",
    " # Append the generated token to the input for the next iteration\n",
    " text.append(max_token_id)\n",
    "\n",
    " # Decode current context for display\n",
    " current = tokenizer.decode(text)\n",
    " print(f\"Decoded Context: {current}\")\n",
    " print(f\"Max Probability Token: '{max_token}' (ID: {max_token_id} word {i})\")\n",
    "\n",
    " # Plot the probability distribution\n",
    " #plot_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
    " print_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
    "\n",
    "# Final Output\n",
    "final_generated_text = tokenizer.decode(text)\n",
    "print(f\"\\nFinal Generated Text: {final_generated_text}\")\n"
    ],
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "lHcKnZdvhz9E",
    "outputId": "cdacf974-a11a-487d-d7a4-86d0812450b8"
    },
    "execution_count": null,
    "outputs": [
    {
    "output_type": "stream",
    "name": "stdout",
    "text": [
    "Vocabulary Size: 50257\n",
    "Original Text: I love New\n",
    "Decoded Context: I love New York\n",
    "Max Probability Token: ' York' (ID: 1971 word 0)\n",
    "Context: I love New York\n",
    "Next Token Prediction: 'York'\n",
    "\n",
    " York | ##################### 0.4355\n",
    " Orleans | #### 0.0972\n",
    " Zealand | #### 0.0885\n",
    " England | ## 0.0504\n",
    " Jersey | # 0.0393\n",
    " Year | # 0.0278\n",
    " Yorkers | 0.0191\n",
    " Mexico | 0.0144\n",
    " Hampshire | 0.0096\n",
    " Years | 0.0085\n",
    "Decoded Context: I love New York.\n",
    "Max Probability Token: '.' (ID: 13 word 1)\n",
    "Context: I love New York.\n",
    "Next Token Prediction: '.'\n",
    "\n",
    " . | ########## 0.2020\n",
    " , | ######## 0.1783\n",
    " and | #### 0.0955\n",
    " City | ### 0.0792\n",
    " ! | # 0.0398\n",
    " ,\" | # 0.0351\n",
    " .\" | # 0.0291\n",
    " !\" | # 0.0273\n",
    " so | # 0.0227\n",
    " 's | # 0.0227\n",
    "Decoded Context: I love New York. I\n",
    "Max Probability Token: ' I' (ID: 314 word 2)\n",
    "Context: I love New York. I\n",
    "Next Token Prediction: 'I'\n",
    "\n",
    " I | ################ 0.3269\n",
    " It | ###### 0.1202\n",
    " | ## 0.0533\n",
    " We | ## 0.0471\n",
    " But | ## 0.0416\n",
    " And | # 0.0390\n",
    " The | # 0.0268\n",
    " You | 0.0184\n",
    " So | 0.0163\n",
    " My | 0.0144\n",
    "Decoded Context: I love New York. I love\n",
    "Max Probability Token: ' love' (ID: 1842 word 3)\n",
    "Context: I love New York. I love\n",
    "Next Token Prediction: 'love'\n",
    "\n",
    " love | ###################### 0.4473\n",
    " 'm | ### 0.0686\n",
    " 've | ## 0.0416\n",
    " like | # 0.0324\n",
    " think | # 0.0252\n",
    " have | # 0.0223\n",
    " live | # 0.0223\n",
    " know | 0.0173\n",
    " don | 0.0173\n",
    " want | 0.0153\n",
    "Decoded Context: I love New York. I love the\n",
    "Max Probability Token: ' the' (ID: 262 word 4)\n",
    "Context: I love New York. I love the\n",
    "Next Token Prediction: 'the'\n",
    "\n",
    " the | ########## 0.2021\n",
    " New | ###### 0.1305\n",
    " it | ## 0.0544\n",
    " my | # 0.0330\n",
    " being | # 0.0257\n",
    " this | # 0.0257\n",
    " to | # 0.0227\n",
    " all | # 0.0227\n",
    " that | 0.0188\n",
    " living | 0.0156\n",
    "Decoded Context: I love New York. I love the city\n",
    "Max Probability Token: ' city' (ID: 1748 word 5)\n",
    "Context: I love New York. I love the city\n",
    "Next Token Prediction: 'city'\n",
    "\n",
    " city | ######### 0.1903\n",
    " people | ## 0.0580\n",
    " way | # 0.0213\n",
    " fact | # 0.0213\n",
    " place | 0.0156\n",
    " New | 0.0147\n",
    " music | 0.0138\n",
    " country | 0.0101\n",
    " great | 0.0079\n",
    " culture | 0.0079\n",
    "Decoded Context: I love New York. I love the city.\n",
    "Max Probability Token: '.' (ID: 13 word 6)\n",
    "Context: I love New York. I love the city.\n",
    "Next Token Prediction: '.'\n",
    "\n",
    " . | ######################## 0.4858\n",
    " , | ##### 0.1154\n",
    " and | ### 0.0793\n",
    " .\" | ### 0.0657\n",
    " of | ## 0.0424\n",
    " ,\" | # 0.0363\n",
    " that | # 0.0201\n",
    " I | 0.0156\n",
    " ! | 0.0114\n",
    " !\" | 0.0079\n",
    "Decoded Context: I love New York. I love the city. I\n",
    "Max Probability Token: ' I' (ID: 314 word 7)\n",
    "Context: I love New York. I love the city. I\n",
    "Next Token Prediction: 'I'\n",
    "\n",
    " I | ############################ 0.5669\n",
    " And | ### 0.0767\n",
    " It | ## 0.0527\n",
    " But | ## 0.0437\n",
    " | # 0.0340\n",
    " We | # 0.0206\n",
    " The | 0.0182\n",
    " So | 0.0110\n",
    " New | 0.0110\n",
    " My | 0.0098\n",
    "Decoded Context: I love New York. I love the city. I love\n",
    "Max Probability Token: ' love' (ID: 1842 word 8)\n",
    "Context: I love New York. I love the city. I love\n",
    "Next Token Prediction: 'love'\n",
    "\n",
    " love | ####################################### 0.7949\n",
    " 'm | # 0.0255\n",
    " like | # 0.0240\n",
    " want | 0.0128\n",
    " think | 0.0094\n",
    " really | 0.0073\n",
    " 've | 0.0073\n",
    " know | 0.0069\n",
    " don | 0.0069\n",
    " have | 0.0065\n",
    "Decoded Context: I love New York. I love the city. I love the\n",
    "Max Probability Token: ' the' (ID: 262 word 9)\n",
    "Context: I love New York. I love the city. I love the\n",
    "Next Token Prediction: 'the'\n",
    "\n",
    " the | ################## 0.3635\n",
    " it | ### 0.0632\n",
    " New | ## 0.0594\n",
    " my | # 0.0338\n",
    " all | # 0.0280\n",
    " that | # 0.0280\n",
    " to | # 0.0218\n",
    " how | # 0.0205\n",
    " being | # 0.0205\n",
    " this | 0.0170\n",
    "\n",
    "Final Generated Text: I love New York. I love the city. I love the\n"
    ]
    }
    ]
    }
    ]
    }
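
For reference, the notebook's core logic condenses into a short standalone script. A minimal sketch, assuming the same gpt2 checkpoint and the transformers/torch setup used above; it reproduces the greedy loop (softmax over the last position's logits, take the argmax, append, repeat) without the plotting helpers:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

text = tokenizer.encode("I love New")  # prompt as a list of token ids
for _ in range(10):                    # generate N = 10 tokens
    with torch.no_grad():
        logits = model(torch.tensor([text])).logits
    probs = torch.softmax(logits[0, -1], dim=0)  # distribution for the last position
    text.append(int(torch.argmax(probs)))        # greedy: keep the most probable token

print(tokenizer.decode(text))
# -> I love New York. I love the city. I love the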
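This hand-rolled loop should agree with Hugging Face's built-in greedy search, which is what generate() performs when sampling is disabled. A quick cross-check, reusing the model and tokenizer from the sketch above:

input_ids = torch.tensor([tokenizer.encode("I love New")])
out = model.generate(
    input_ids,
    max_new_tokens=10,                    # same N as the manual loop
    do_sample=False,                      # greedy decoding, no sampling
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; avoids a warning
)
print(tokenizer.decode(out[0]))           # should match the manual loop's output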