Forked from mf1024/Fine-tuning GPT2-medium in PyTorch.ipynb
Created February 27, 2022 09:22
Revisions
mf1024 revised this gist
Dec 17, 2019. No changes.
mf1024 revised this gist
Dec 17, 2019. 1 changed file with 28 additions and 13 deletions.

Summary of the diff:

- The introductory markdown cell links to the published blog post, [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](https://mf1024.github.io/2019/11/12/Fun-With-GPT-2/).
- "collapsed": true was added to the metadata of several code cells.
- The notebook-level metadata in the hunk shows a Python 2 kernelspec (language_info version 2.7.14, nbformat 4).
mf1024 revised this gist
Dec 17, 2019. 1 changed file with 17 additions and 8 deletions.

Summary of the diff (mostly markdown wording):

- The pointer to the earlier text-generation [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) now simply notes that it uses the same pretrained model to generate text.
- The dataset section is titled "PyTorch Dataset module for Short jokes dataset" and references the [Short Jokes dataset](https://www.kaggle.com/abhinavmoudgil95/short-jokes) published on Kaggle. After each joke an "<|endoftext|>" token is appended, which the GPT-2 model recognizes as an end-of-text marker and which allows many jokes to be concatenated into a single input sequence (see the sketch after this list).
- The hyperparameters note now says that more than 5 hyperparameter sets were tested, tuning mainly ***BATCH_SIZE*** (here the number of forward-backward passes between optimization steps), ***EPOCHS***, and ***LEARNING_RATE***, with starting values inspired by the huggingface [run_squad.py](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py) fine-tuning code.
- The comments in the generation loop were reworded: the next token is randomly selected from the top-N probability distribution and appended to the running sequence.
- A concluding markdown cell was added: the 3rd-epoch model seemed to perform the best, and because the generated jokes were too long for a notebook they were stored in [this file](https://github.com/mf1024/transformers/blob/master/generated_2_jokes.txt).
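A minimal sketch of the end-of-text idea described above, assuming only a pretrained GPT-2 tokenizer; the helper name and the sample jokes are illustrative and not part of the gist:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
END_OF_TEXT = "<|endoftext|>"

def pack_jokes(jokes):
    # Append the end-of-text marker after every joke so that several jokes
    # can be concatenated into one long training sequence.
    return "".join(f"JOKE:{joke}{END_OF_TEXT}" for joke in jokes)

text = pack_jokes([
    "Why did the chicken cross the road?",
    "I used to be a banker, but I lost interest.",
])
ids = tokenizer.encode(text)

# The marker is encoded as GPT-2's single end-of-text token id (50256 in the
# standard vocabulary), which the generation loop later uses as a stop signal.
print(tokenizer.encode(END_OF_TEXT))
```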
mf1024 revised this gist
Dec 17, 2019. 1 changed file with 130 additions and 172 deletions.

This is the large rewrite of the notebook: it replaces the Reddit jokes pipeline with the Kaggle Short Jokes one and adds the saved training output. The visible parts of the first half of the diff:

- choose_from_top(probs, n=5) selects the n highest-probability token ids with np.argpartition(probs, -n)[-n:], renormalizes their probabilities, and samples one of them with np.random.choice.
- JokesDataset now reads jokes_data/shortjokes.csv with the csv module; every row becomes the string f"JOKE:{row[1]}{self.end_of_text_token}" with end_of_text_token = "<|endoftext|>", collected into joke_list, and __len__ returns the length of that list. The imports pull Dataset and DataLoader from torch.utils.data plus os, json, and csv.
- The hyperparameter cell sets BATCH_SIZE = 16, EPOCHS = 5, LEARNING_RATE = 3e-5, WARMUP_STEPS = 5000, and MAX_SEQ_LEN = 400, and imports AdamW and WarmupLinearSchedule from transformers.
- A "Model training" markdown cell states the plan: train the model, save the weights after each epoch, and then generate jokes with each saved version to see which performs best.

A sketch of these pieces, reconstructed from the diff, follows this list.
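The sketch below restates the sampling helper, the dataset, and the hyperparameter constants as they can be read off the diff. The return statement of choose_from_top and the __getitem__ method are filled in because they are cut off in the flattened diff, and the joke_loader line is taken from an earlier revision shown further down this page; treat those details as assumptions.

```python
import csv
import os

import numpy as np
from torch.utils.data import Dataset, DataLoader

def choose_from_top(probs, n=5):
    # Keep the n most probable token ids, renormalize, and sample one of them.
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob)  # Normalize
    choice = np.random.choice(n, 1, p=top_prob)
    token_id = ind[choice][0]  # assumed: map the sampled position back to a token id
    return int(token_id)

class JokesDataset(Dataset):
    def __init__(self, jokes_dataset_path='jokes_data/'):
        super().__init__()
        short_jokes_path = os.path.join(jokes_dataset_path, 'shortjokes.csv')
        self.joke_list = []
        self.end_of_text_token = "<|endoftext|>"
        with open(short_jokes_path) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            for row in csv_reader:
                # Column 1 of the Kaggle CSV holds the joke text.
                joke_str = f"JOKE:{row[1]}{self.end_of_text_token}"
                self.joke_list.append(joke_str)

    def __len__(self):
        return len(self.joke_list)

    def __getitem__(self, item):
        return self.joke_list[item]

BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000
MAX_SEQ_LEN = 400

dataset = JokesDataset()
joke_loader = DataLoader(dataset, batch_size=1, shuffle=True)
```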
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "EPOCH 0 started==============================\n", "sum loss 6755.8515625\n", "sum loss 6162.54736328125\n", "sum loss 5615.0087890625\n", "sum loss 5427.8349609375\n", "sum loss 5346.51416015625\n", "sum loss 5282.28466796875\n", "sum loss 5242.89794921875\n", "sum loss 5210.7109375\n", "sum loss 5167.60400390625\n", "EPOCH 1 started==============================\n", "sum loss 5125.49658203125\n", "sum loss 5123.25830078125\n", "sum loss 5083.72998046875\n", "sum loss 5077.42138671875\n", "sum loss 5054.7705078125\n", "sum loss 5012.50830078125\n", "sum loss 5004.54150390625\n", "sum loss 4992.71435546875\n", "sum loss 4983.85205078125\n", "sum loss 4962.859375\n", "EPOCH 2 started==============================\n", "sum loss 4922.1162109375\n", "sum loss 4891.734375\n", "sum loss 4888.96923828125\n", "sum loss 4877.55224609375\n", "sum loss 4866.8857421875\n", "sum loss 4853.67236328125\n", "sum loss 4847.48681640625\n", "sum loss 4848.1748046875\n", "sum loss 4823.41162109375\n", "sum loss 4812.2216796875\n", "EPOCH 3 started==============================\n", "sum loss 4760.50537109375\n", "sum loss 4739.74072265625\n", "sum loss 4735.41943359375\n", "sum loss 4724.58837890625\n", "sum loss 4720.5361328125\n", "sum loss 4702.38671875\n", "sum loss 4703.87744140625\n", "sum loss 4696.990234375\n", "sum loss 4683.61181640625\n", "EPOCH 4 started==============================\n", "sum loss 4663.662109375\n", "sum loss 4584.58251953125\n", "sum loss 4596.220703125\n", "sum loss 4595.849609375\n", "sum loss 4583.38427734375\n", "sum loss 4582.66552734375\n", "sum loss 4574.6728515625\n", "sum loss 4563.45166015625\n", "sum loss 4559.3408203125\n", "sum loss 4564.5859375\n" ] } ], @@ -274,6 +232,9 @@ "batch_count = 0\n", "\n", "tmp_jokes_tens = None\n", "models_folder = \"trained_models\"\n", "if not os.path.exists(models_folder):\n", " os.mkdir(models_folder)\n", "\n", "for epoch in range(EPOCHS):\n", " \n", @@ -316,89 +277,86 @@ " scheduler.step() \n", " optimizer.zero_grad()\n", " model.zero_grad()\n", "\n", " if batch_count == 100:\n", " print(f\"sum loss {sum_loss}\")\n", " batch_count = 0\n", " sum_loss = 0.0\n", " \n", " # Store the model after each epoch to compare the performance of them\n", " torch.save(model.state_dict(), os.path.join(models_folder, f\"gpt2_medium_joker_{epoch}.pt\"))\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating the jokes" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "MODEL_EPOCH = 4\n", "\n", "models_folder = \"trained_models\"\n", "\n", "model_path = os.path.join(models_folder, f\"gpt2_medium_joker_{MODEL_EPOCH}.pt\")\n", "model.load_state_dict(torch.load(model_path))\n", "\n", "jokes_output_file_path = f'generated_{MODEL_EPOCH}.jokes'\n", "\n", "model.eval()\n", "if os.path.exists(jokes_output_file_path):\n", " os.remove(jokes_output_file_path)\n", " \n", "joke_num = 0\n", "with torch.no_grad():\n", " \n", " for joke_idx in range(1000):\n", " \n", " joke_finished = False\n", "\n", " cur_ids = torch.tensor(tokenizer.encode(\"JOKE:\")).unsqueeze(0).to(device)\n", "\n", " for i in range(100):\n", " outputs = model(cur_ids, labels=cur_ids)\n", " loss, logits = outputs[:2]\n", " softmax_logits = torch.softmax(logits[0,-1], dim=0) #Take the first(from only one in this case) batch and the last predicted embedding\n", " if i < 3:\n", " n = 20\n", " 
else:\n", " n = 3\n", " next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) choose the next word from the top n words\n", " cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word\n", "\n", " if next_token_id in tokenizer.encode('<|endoftext|>'):\n", " joke_finished = True\n", " break\n", "\n", " \n", " if joke_finished:\n", " \n", " joke_num = joke_num + 1\n", " \n", " output_list = list(cur_ids.squeeze().to('cpu').numpy())\n", " output_text = tokenizer.decode(output_list)\n", "\n", " with open(jokes_output_file_path, 'a') as f:\n", " f.write(f\"{output_text} \\n\\n\")\n", " \n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output was too long, so I stored it in [this file](https://github.com/mf1024/transformers/blob/master/generated_2_jokes.txt)." ] } ], -
mf1024 revised this gist
Dec 12, 2019. 1 changed file with 2 additions and 2 deletions.

The diff only rewords the introductory markdown cell: the blog-post link still carries the (www) placeholder, and the pointer to the text-generation [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) now says it contains some more details about setting up and using the pre-trained model for text generation.
mf1024 revised this gist
Dec 12, 2019. 1 changed file with 2 additions and 2 deletions.

Another two-line rewording of the introduction, touching the sentence about the pre-trained GPT-2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers) and the pointer to the text-generation [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e).
mf1024 revised this gist
Dec 12, 2019. 1 changed file with 1 addition and 1 deletion.

A one-word change in the opening sentence of the first markdown cell ("This notebook was created as a part of a blog post ..."); the blog-post link is still the (www) placeholder at this point.
mf1024 revised this gist
Dec 12, 2019. No changes.
mf1024 revised this gist
Dec 12, 2019. 1 changed file with 12 additions and 36 deletions.

This earlier revision still targets Reddit jokes, and it also removes a number of cells. Visible changes:

- The introduction is reworded; the blog-post link is a (www) placeholder, and the pointer to the *GPT2LMHeadModel* text-generation [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) mentions more details on setting up and using the pre-trained model.
- In choose_from_top, the argpartition line appears garbled in the diff (np.argpartition(probs, -n)top[-n:]); it reads np.argpartition(probs, -n)[-n:] in later revisions.
- The dataset section, "PyTorch Dataset module for Reddit jokes", uses Reddit jokes from [this](https://github.com/taivop/joke-dataset/blob/master/reddit_jokes.json) dataset; after each joke an "<|endoftext|>" token is appended, which GPT-2 recognizes as an end-of-text marker, so that many jokes can be concatenated into a single input sequence (a hypothetical reconstruction of such a dataset class is sketched after this list).
- The hyperparameters note says about 5 hyperparameter sets were tested, again tuning ***BATCH_SIZE***, ***EPOCHS***, and ***LEARNING_RATE***, starting from the huggingface example scripts.
- The link to the generated-jokes output file is still a (github.link) placeholder, and a "Load stored model to generate more jokes" section loads the weights with model.load_state_dict(torch.load("gpt2_medium_joker.pt")).
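For comparison with the later CSV-based JokesDataset, here is a hypothetical reconstruction of what a Reddit-jokes Dataset along these lines could look like. The field names ("title", "body") follow the layout of the linked joke-dataset repository and the file path mirrors the later notebook; none of this code appears verbatim in the gist.

```python
import json
import os

from torch.utils.data import Dataset

class RedditJokesDataset(Dataset):
    def __init__(self, jokes_dataset_path='jokes_data/'):
        super().__init__()
        reddit_jokes_path = os.path.join(jokes_dataset_path, 'reddit_jokes.json')
        self.end_of_text_token = "<|endoftext|>"

        with open(reddit_jokes_path) as f:
            jokes = json.load(f)

        # Assumed layout: a list of objects with "title" and "body" fields.
        self.joke_list = [
            f"JOKE: {joke['title']} {joke['body']}{self.end_of_text_token}"
            for joke in jokes
        ]

    def __len__(self):
        return len(self.joke_list)

    def __getitem__(self, item):
        return self.joke_list[item]
```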
mf1024 revised this gist
Nov 22, 2019. 1 changed file with 49 additions and 4022 deletions.
mf1024 created this gist
Nov 21, 2019.