Created
March 6, 2022 05:45
-
-
Save amoux/b0ea4db5d5c3f29c8205bd24aee1b47c to your computer and use it in GitHub Desktop.
Revisions
-
amoux revised this gist
Mar 6, 2022 . 1 changed file with 12 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,15 @@ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/gist/amoux/b0ea4db5d5c3f29c8205bd24aee1b47c/char2char-rnn-with-ego.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "code", "execution_count": null, @@ -966,7 +976,8 @@ }, "colab": { "name": "Char2Char RNN with ego 🔥.ipynb", "provenance": [], "include_colab_link": true } }, "nbformat": 4, -
amoux created this gist
Mar 6, 2022 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,974 @@ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "d0206cbf", "metadata": { "id": "d0206cbf" }, "outputs": [], "source": [ "import ego\n", "import ego.nn as nn\n", "import ego.nn.functional as F\n", "import ego.optim as optim\n", "import ego.text as text" ] }, { "cell_type": "code", "execution_count": null, "id": "c552c25c", "metadata": { "id": "c552c25c", "outputId": "2777c7b6-f335-4919-c4a7-3de87b90af4b" }, "outputs": [ { "data": { "text/plain": [ "Tokenizer(vocab_size=19, maxlen=15, type=char, base=<CharPreTokenization>)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corpus = ['hey how are you',\n", " 'good i am fine',\n", " 'have a nice day']\n", "\n", "tokenizer = text.Tokenizer('char', min_freq=0, lower_case=True, \n", " use_special_tokens=False).fit(corpus)\n", "tokenizer" ] }, { "cell_type": "code", "execution_count": null, "id": "162f2d20", "metadata": { "id": "162f2d20", "outputId": "0aea901f-e391-4c12-bdc4-04b11f0ac83f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x : hey how are yo\n", "y : ey how are you\n", "x : good i am fine\n", "y : ood i am fine \n", "x : have a nice da\n", "y : ave a nice day\n" ] } ], "source": [ "PAD = \" \"\n", "seqs = corpus.copy()\n", "for i in range(len(seqs)):\n", " while len(seqs[i]) < tokenizer.maxlen:\n", " seqs[i] += PAD\n", "input_seqs = []\n", "target_seqs = []\n", "for i, seq in enumerate(seqs):\n", " assert len(seq) == tokenizer.maxlen\n", " input_seqs.append(seq[:-1])\n", " target_seqs.append(seq[1:])\n", " print(f\"x : {input_seqs[i]}\\ny : {target_seqs[i]}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5c9e72a5", "metadata": { "id": "5c9e72a5", "outputId": "a1b24753-20f1-4a0b-fb0e-1c2d81f0cbbb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 6, 3, 7, 2, 6, 5, 11, 2, 4, 12, 3, 2, 7, 5]])\n", "tensor([[ 3, 7, 2, 6, 5, 11, 2, 4, 12, 3, 2, 7, 5, 13]]) \n", "\n", "['h', 'e', 'y', ' ', 'h', 'o', 'w', ' ', 'a', 'r', 'e', ' ', 'y', 'o']\n", "['e', 'y', ' ', 'h', 'o', 'w', ' ', 'a', 'r', 'e', ' ', 'y', 'o', 'u']\n", "X: ego.Size([3, 14, 19]) | y: ego.Size([3, 14])\n" ] } ], "source": [ "input_ids = tokenizer(input_seqs, return_attention_mask=False,\n", " return_tensors=\"ego\")\n", "target_ids = tokenizer(target_seqs, return_attention_mask=False,\n", " return_tensors=\"ego\")\n", "input_features = F.one_hot(input_ids, num_classes=tokenizer.vocab_size)\n", "\n", "print(input_ids[:1])\n", "print(target_ids[:1], '\\n')\n", "print(tokenizer.decode(input_ids[0]))\n", "print(tokenizer.decode(target_ids[0]))\n", "print(\"X: {} | y: {}\".format(input_features.size(), target_ids.size()))" ] }, { "cell_type": "code", "execution_count": null, "id": "6cba9c06", "metadata": { "id": "6cba9c06", "outputId": "b03275b6-3ce4-4528-b281-a9fbac18146d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Char2CharRNN(\n", " (rnn): GRU(19, 32, batch_first=True)\n", " (fc): Linear(in_features=32, out_features=19, bias=True)\n", ")\n" ] } ], "source": [ "class Char2CharRNN(nn.Module):\n", " def __init__(self, ninp, nout, nhid, nlayer=1, dropout=0, is_batch=True):\n", " super(Char2CharRNN, self).__init__()\n", " self.nhid = nhid\n", " self.rnn = nn.GRU(ninp, nhid, nlayer, bias=True,\n", " batch_first=is_batch, dropout=dropout)\n", " self.fc = nn.Linear(nhid, nout)\n", "\n", " def inference(self):\n", " self.eval() # set self, parent and children modules to eval mode\n", " # set the requires_grad flag to False for all defined parameters\n", " self.requires_grad_(False)\n", " # clear accumulated gradients for all defined parameters (if any)\n", " self.zero_grad(set_to_none=True)\n", "\n", " def forward(self, x):\n", " ih, hh = self.rnn(x)\n", " hidden = ih.view(-1, self.nhid)\n", " logits = self.fc(hidden)\n", " if self.training:\n", " # keep tensor 2D from <Linear>[BxT, Nt] -> <NLLLoss>(x:2D, y:1D)\n", " return logits\n", " else:\n", " # otherwise, reshape as 3D<OneHot>[B, T, Nt]\n", " return logits.reshape_as(x)\n", "\n", "\n", "# vocab_size, hidden_size, num_layers\n", "Nt, H, L = (tokenizer.vocab_size, 32, 1)\n", "\n", "model = Char2CharRNN(Nt, Nt, H, L)\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.RMSprop(model.parameters(), lr=0.01)\n", "print(model)" ] }, { "cell_type": "code", "execution_count": null, "id": "d5a57b12", "metadata": { "scrolled": true, "id": "d5a57b12", "outputId": "4eac87fb-ab4a-41f7-91b3-989bc48118a3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 0/50\t.......... loss : 2.9805\n", "epoch : 1/50\t.......... loss : 2.6823\n", "epoch : 2/50\t.......... loss : 2.5745\n", "epoch : 3/50\t.......... loss : 2.2057\n", "epoch : 4/50\t.......... loss : 1.8892\n", "epoch : 5/50\t.......... loss : 1.5189\n", "epoch : 6/50\t.......... loss : 1.2244\n", "epoch : 7/50\t.......... loss : 1.0763\n", "epoch : 8/50\t.......... loss : 0.9662\n", "epoch : 9/50\t.......... loss : 0.6695\n", "epoch : 10/50\t.......... loss : 0.5876\n", "epoch : 11/50\t.......... loss : 0.4450\n", "epoch : 12/50\t.......... loss : 0.3818\n", "epoch : 13/50\t.......... loss : 0.3738\n", "epoch : 14/50\t.......... loss : 0.2685\n", "epoch : 15/50\t.......... loss : 0.1938\n", "epoch : 16/50\t.......... loss : 0.1572\n", "epoch : 17/50\t.......... loss : 0.1356\n", "epoch : 18/50\t.......... loss : 0.1208\n", "epoch : 19/50\t.......... loss : 0.1094\n", "epoch : 20/50\t.......... loss : 0.1004\n", "epoch : 21/50\t.......... loss : 0.0931\n", "epoch : 22/50\t.......... loss : 0.0872\n", "epoch : 23/50\t.......... loss : 0.0822\n", "epoch : 24/50\t.......... loss : 0.0780\n", "epoch : 25/50\t.......... loss : 0.0743\n", "epoch : 26/50\t.......... loss : 0.0712\n", "epoch : 27/50\t.......... loss : 0.0684\n", "epoch : 28/50\t.......... loss : 0.0660\n", "epoch : 29/50\t.......... loss : 0.0639\n", "epoch : 30/50\t.......... loss : 0.0620\n", "epoch : 31/50\t.......... loss : 0.0602\n", "epoch : 32/50\t.......... loss : 0.0587\n", "epoch : 33/50\t.......... loss : 0.0573\n", "epoch : 34/50\t.......... loss : 0.0560\n", "epoch : 35/50\t.......... loss : 0.0549\n", "epoch : 36/50\t.......... loss : 0.0538\n", "epoch : 37/50\t.......... loss : 0.0528\n", "epoch : 38/50\t.......... loss : 0.0519\n", "epoch : 39/50\t.......... loss : 0.0511\n", "epoch : 40/50\t.......... loss : 0.0503\n", "epoch : 41/50\t.......... loss : 0.0496\n", "epoch : 42/50\t.......... loss : 0.0490\n", "epoch : 43/50\t.......... loss : 0.0483\n", "epoch : 44/50\t.......... loss : 0.0478\n", "epoch : 45/50\t.......... loss : 0.0472\n", "epoch : 46/50\t.......... loss : 0.0467\n", "epoch : 47/50\t.......... loss : 0.0462\n", "epoch : 48/50\t.......... loss : 0.0458\n", "epoch : 49/50\t.......... loss : 0.0453\n" ] } ], "source": [ "epochs = 50\n", "x = input_features\n", "y = target_ids.view(-1)\n", "for epoch in range(epochs):\n", " optimizer.zero_grad()\n", " yhat = model(x)\n", " loss = criterion(yhat, y)\n", " loss.backward()\n", " nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " optimizer.step()\n", " print('epoch : {}/{}\\t..........'.format(epoch, epochs), end=' ')\n", " print('loss : {:.4f}'.format(loss.item()))" ] }, { "cell_type": "code", "execution_count": null, "id": "e8abf015", "metadata": { "id": "e8abf015" }, "outputs": [], "source": [ "model.inference() # set the model in inference mode" ] }, { "cell_type": "code", "execution_count": null, "id": "94233747", "metadata": { "id": "94233747" }, "outputs": [], "source": [ "def print_pred(prompt_text, next_token, score):\n", " print(f'predicted token: {prompt_text} + [{tokenizer[next_token]}] '\n", " f'| ~P(next|prompt) = {score.item():.4f}%')" ] }, { "cell_type": "code", "execution_count": null, "id": "21d6402b", "metadata": { "id": "21d6402b", "outputId": "f196bb44-4ce4-49ee-b8f1-4fa379156381" }, "outputs": [ { "data": { "text/plain": [ "tensor([[[-1.7822, -2.0948, 0.0980, 5.4613, 5.4769, 0.9236, -2.9440, -0.2204,\n", " -0.5596, -1.2992, -1.4227, -1.4685, -1.4688, -3.1646, -2.0938, -2.5651,\n", " -3.2602, 0.2805, -1.9711],\n", " [-2.0338, -1.8651, 2.7062, -1.1563, 0.1298, -0.1714, -1.7824, 7.6752,\n", " -5.0695, -0.5060, -2.5180, 0.4877, -0.9850, -1.4442, -2.2192, 0.2705,\n", " -3.1254, 3.0572, -2.8077]]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt_text = \"he\"\n", "prompt_tensor = tokenizer.encode(prompt_text, return_tensors=\"ego\").unsqueeze(0)\n", "prompt_tensor = F.one_hot(prompt_tensor, tokenizer.vocab_size)\n", "output = model(prompt_tensor)\n", "output" ] }, { "cell_type": "code", "execution_count": null, "id": "9848127f", "metadata": { "id": "9848127f", "outputId": "cbf214cf-30f0-49a4-cf33-941ae31c817e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "predicted token: he + [a] | ~P(next|prompt) = 0.4955%\n", "\n", "\t_________________________________\n", "\n", "predicted token: he + [y] | ~P(next|prompt) = 0.9802%\n" ] } ], "source": [ "ith_layer_0 = 0 # first layer hidden states\n", "ith_layer_1 = 1 # second layer hidden states\n", "\n", "score, next_token = output[:, ith_layer_0].softmax(-1).topk(1)\n", "print_pred(prompt_text, next_token, score)\n", "print('\\n\\t_________________________________\\n')\n", "\n", "score, next_token = output[:, ith_layer_1].softmax(-1).topk(1)\n", "print_pred(prompt_text, next_token, score)" ] }, { "cell_type": "code", "execution_count": null, "id": "d6af77f4", "metadata": { "id": "d6af77f4" }, "outputs": [], "source": [ "def predict(tokens):\n", " if isinstance(tokens, str):\n", " tokens = tokenizer.tokenize(tokens)\n", " token_ids = [tokenizer[token] for token in tokens]\n", " features = F.one_hot(ego.tensor([token_ids]),\n", " num_classes=tokenizer.vocab_size)\n", " with ego.no_grad():\n", " logits = model(features)\n", " score, pred = logits[:, -1, :].softmax(dim=-1).topk(1)\n", " return {'token_id': [int(pred)],\n", " 'token': [tokenizer[pred]], 'score': float(score)}\n", "\n", "\n", "def sample(prompt: str, length=5):\n", " tokens = [char for char in prompt]\n", " length = length - len(tokens)\n", " for _ in range(length):\n", " pred_token = predict(tokens)['token']\n", " tokens.append(pred_token[0])\n", " return \"\".join(tokens)" ] }, { "cell_type": "code", "execution_count": null, "id": "a87a532a", "metadata": { "id": "a87a532a", "outputId": "77814eb2-f9ce-4690-dfb2-10281de8c3f5" }, "outputs": [ { "data": { "text/plain": [ "{'token_id': [7], 'token': ['y'], 'score': 0.9802330732345581}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"he\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7efb2f84", "metadata": { "id": "7efb2f84", "outputId": "d3952dad-db66-4d48-ee7f-9e7e0fea5c07" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "good i am fine \n" ] } ], "source": [ "print(sample(\"good\", tokenizer.maxlen))" ] }, { "cell_type": "code", "execution_count": null, "id": "fe42803a", "metadata": { "id": "fe42803a", "outputId": "6d01871f-cf42-4091-bb91-22fa98f69f03" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hey how are you\n" ] } ], "source": [ "print(sample(\"hey\", tokenizer.maxlen))" ] }, { "cell_type": "code", "execution_count": null, "id": "c50a0097", "metadata": { "id": "c50a0097", "outputId": "dc2f169d-04de-4583-dd51-1a3ab70bc671" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "have a nice day\n" ] } ], "source": [ "print(sample(\"have\", tokenizer.maxlen))" ] }, { "cell_type": "code", "execution_count": null, "id": "cf844af1", "metadata": { "id": "cf844af1", "outputId": "0695753b-c001-4d12-acae-f1496b411623" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hey i have youu i am \n" ] } ], "source": [ "print(sample(\"hey i have\", tokenizer.maxlen + 6))" ] }, { "cell_type": "code", "execution_count": null, "id": "a403e0cf", "metadata": { "id": "a403e0cf", "outputId": "98828421-7c30-4fb9-eb64-749de74f8fdf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "good i am fine and i am fine a nice\n" ] } ], "source": [ "print(sample(\"good i am fine and\", tokenizer.maxlen + 20))" ] }, { "cell_type": "code", "execution_count": null, "id": "0fe22588", "metadata": { "scrolled": true, "id": "0fe22588", "outputId": "7d824784-1d9f-47c5-e8f6-2bcb7b10cbf0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0156 ms ((3, 14, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0083 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2861 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0082 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3554 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0269 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0592 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0211 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0461 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0264 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0115 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0233 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0160 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0079 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2772 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0084 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1313 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0217 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0420 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0191 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0374 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0212 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0094 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0155 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0150 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0076 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1590 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0075 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1783 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0198 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0446 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0193 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0389 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0210 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0095 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0154 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0150 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0110 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3797 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0077 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3453 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0248 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0477 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0246 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0383 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0229 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0105 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0155 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0162 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0102 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2908 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0079 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2809 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0231 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0467 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0190 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0385 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0224 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0102 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0156 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0149 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0076 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.8516 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0079 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2770 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0204 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0413 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0187 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0382 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0223 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0102 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0161 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0148 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0077 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.5585 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0107 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.2777 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0274 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0401 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0182 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0382 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0205 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0092 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0147 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0142 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0071 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.6203 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0074 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1616 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0204 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0465 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0180 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0357 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0199 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0096 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0148 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0147 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0076 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 1.0029 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0072 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1903 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0193 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0406 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0180 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0364 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0204 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0096 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0151 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0140 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0072 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.7133 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0072 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3172 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0191 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0397 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0178 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0365 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0212 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0095 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0151 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0141 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0070 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.5525 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0071 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3067 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0188 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0396 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0188 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0377 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0216 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0098 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0290 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0157 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0077 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.8759 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0073 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.3185 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0200 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0404 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0182 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0362 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0210 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0094 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0150 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0141 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0070 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.8999 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0118 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.4159 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0338 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0729 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0299 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0589 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0312 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0154 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0285 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0272 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0165 ms ((96, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1744 ms ((96,), (3, 19), (19, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0078 ms ((96, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.1335 ms ((96,), (3, 32), (32, 96), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0255 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0483 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Add (n_inp[2], n_out[1]) 0.0190 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Sigmoid (n_inp[1], n_out[1]) 0.0351 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0243 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Tanh (n_inp[1], n_out[1]) 0.0162 ms ((3, 32),)\u001b[0m\n", "\u001b[1m\u001b[36m Sub (n_inp[2], n_out[1]) 0.0177 ms ((3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Addcmul (n_inp[4], n_out[1]) 0.0217 ms ((3, 32), (3, 32), (3, 32), ('float',))\u001b[0m\n", "\u001b[1m\u001b[36m Concat (n_inp[16], n_out[1]) 1.2842 ms (('int',), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m View (n_inp[2], n_out[1]) 0.0153 ms ((42, 32), ('tuple',))\u001b[0m\n", "\u001b[1m\u001b[36m Concat (n_inp[3], n_out[1]) 0.0191 ms (('int',), (14, 3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m Concat (n_inp[3], n_out[1]) 0.0139 ms (('int',), (3, 32))\u001b[0m\n", "\u001b[1m\u001b[36m View (n_inp[2], n_out[1]) 0.0089 ms ((3, 32), ('tuple',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0095 ms ((14, 3, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m View (n_inp[2], n_out[1]) 0.0160 ms ((3, 14, 32), ('tuple',))\u001b[0m\n", "\u001b[1m\u001b[36m Transpose (n_inp[3], n_out[1]) 0.0074 ms ((19, 32), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Addmm (n_inp[5], n_out[1]) 0.6025 ms ((19,), (42, 32), (32, 19), ('int',), ('int',))\u001b[0m\n", "\u001b[1m\u001b[36m Reshape (n_inp[2], n_out[1]) 0.0088 ms ((42, 19), (3, 14, 19))\u001b[0m\n" ] } ], "source": [ "# Forward graph computation for input with batch_size 3\n", "with ego.enable_profile():\n", " logits = model(x)" ] }, { "cell_type": "code", "execution_count": null, "id": "32d2c8ed", "metadata": { "id": "32d2c8ed", "outputId": "eb458ff0-e73c-4c8a-db3a-343abe7c74fb" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.1074, 0.1042, -0.5305, -1.6488, -3.4726, 2.2977, 1.0623, 1.2069,\n", " -1.5110, 1.4202, -0.3740, 2.7155, -0.7314, 8.2732, 0.5616, -1.2247,\n", " -3.5442, -2.0093, -0.8996],\n", " [-2.3912, -1.8889, 9.2165, -0.9056, -1.6622, -0.2876, -3.7539, 1.7835,\n", " -4.3580, -0.9843, 0.1061, 0.8117, -3.0761, -1.5708, -2.2181, -0.0827,\n", " -1.1877, -1.8860, 1.5904],\n", " [-2.5384, -1.8931, 1.7497, -0.5150, -0.3946, -2.7083, -0.3823, 7.4448,\n", " -0.7722, -3.3754, 0.9520, 0.1721, 1.2855, -0.6555, -2.6928, 0.0379,\n", " -2.8825, 0.6843, -3.4892]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logits[:, -1, :] # raw token logits (dim=1) per batch (dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "913d2ca8", "metadata": { "id": "913d2ca8", "outputId": "249a2eac-1176-414e-d990-c78d8a10f986" }, "outputs": [ { "data": { "text/plain": [ "tensor([[2.2683e-04, 2.8028e-04, 1.4857e-04, 4.8559e-05, 7.8380e-06, 2.5133e-03,\n", " 7.3064e-04, 8.4427e-04, 5.5734e-05, 1.0450e-03, 1.7374e-04, 3.8168e-03,\n", " 1.2154e-04, 9.8933e-01, 4.4284e-04, 7.4210e-05, 7.2967e-06, 3.3864e-05,\n", " 1.0272e-04],\n", " [9.0792e-06, 1.5004e-05, 9.9822e-01, 4.0107e-05, 1.8822e-05, 7.4410e-05,\n", " 2.3239e-06, 5.9035e-04, 1.2703e-06, 3.7073e-05, 1.1031e-04, 2.2338e-04,\n", " 4.5771e-06, 2.0623e-05, 1.0795e-05, 9.1336e-05, 3.0250e-05, 1.5047e-05,\n", " 4.8669e-04],\n", " [4.5644e-05, 8.7029e-05, 3.3244e-03, 3.4527e-04, 3.8945e-04, 3.8515e-05,\n", " 3.9426e-04, 9.8868e-01, 2.6695e-04, 1.9764e-05, 1.4971e-03, 6.8638e-04,\n", " 2.0898e-03, 2.9999e-04, 3.9113e-05, 6.0017e-04, 3.2358e-05, 1.1456e-03,\n", " 1.7640e-05]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores = logits[:, -1, :].softmax(-1) # logits to token probabilities\n", "scores" ] }, { "cell_type": "code", "execution_count": null, "id": "12e93e07", "metadata": { "id": "12e93e07", "outputId": "1444b9a8-5846-4ebf-b85a-aa1e1d761f3b" }, "outputs": [ { "data": { "text/plain": [ "ego.return_types.topk(\n", "values=tensor([[0.9893],\n", " [0.9982],\n", " [0.9887]]),\n", "indices=tensor([[13],\n", " [ 2],\n", " [ 7]]))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores.topk(1) # predicted token likelihood(%) and token_id(n)" ] }, { "cell_type": "code", "execution_count": null, "id": "1ccf2904", "metadata": { "id": "1ccf2904", "outputId": "4c7c8c1b-748e-4e6f-9320-4b28b9a248e9" }, "outputs": [ { "data": { "text/plain": [ "['u', ' ', 'y']" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next_tokens = tokenizer.decode(scores.topk(1).indices.view(-1))\n", "next_tokens # next token per batch" ] }, { "cell_type": "code", "execution_count": null, "id": "bb5fa37c", "metadata": { "id": "bb5fa37c", "outputId": "165456b3-9e1f-4607-8c7f-c47490f6e878" }, "outputs": [ { "data": { "text/plain": [ "['hey how are yo', 'good i am fine', 'have a nice da']" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "decoded_sentences = tokenizer.batch_decode(x.argmax(dim=-1), return_tokens=False)\n", "decoded_sentences" ] }, { "cell_type": "code", "execution_count": null, "id": "61ccd821", "metadata": { "id": "61ccd821", "outputId": "ad5d62e6-2dee-41c3-f44d-cd1e7d4e16cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "joined_sentence_next_token = 'hey how are you'\n", "joined_sentence_next_token = 'good i am fine '\n", "joined_sentence_next_token = 'have a nice day'\n" ] } ], "source": [ "# reconstructing the original sentences.\n", "for sentence, next_token in zip(decoded_sentences, next_tokens):\n", " joined_sentence_next_token = sentence + next_token\n", " print(f\"{joined_sentence_next_token = }\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7866e27c", "metadata": { "id": "7866e27c" }, "outputs": [], "source": [ "" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false }, "colab": { "name": "Char2Char RNN with ego 🔥.ipynb", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }