Revisions

  1. @mf1024 mf1024 revised this gist Dec 17, 2019. No changes.
  2. @mf1024 mf1024 revised this gist Dec 17, 2019. 1 changed file with 28 additions and 13 deletions.
    41 changes: 28 additions & 13 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -6,7 +6,7 @@
    "source": [
    "# Fine-tuning GPT-2 on a jokes dataset in PyTorch\n",
    "\n",
    "This notebook was created as a part of a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "This notebook was created as a part of a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](https://mf1024.github.io/2019/11/12/Fun-With-GPT-2/). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "\n",
    "Let's see if the model can learn to crack some jokes!\n",
    "\n",
    @@ -19,6 +19,7 @@
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {
    "collapsed": true,
    "scrolled": true
    },
    "outputs": [],
    @@ -41,7 +42,9 @@
    {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')\n",
    @@ -52,7 +55,9 @@
    {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "def choose_from_top(probs, n=5):\n",
    @@ -76,7 +81,9 @@
    {
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "from torch.utils.data import Dataset\n",
    @@ -112,7 +119,9 @@
    {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "dataset = JokesDataset()\n",
    @@ -133,7 +142,9 @@
    {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "BATCH_SIZE = 16\n",
    @@ -298,7 +309,9 @@
    {
    "cell_type": "code",
    "execution_count": 14,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "MODEL_EPOCH = 4\n",
    @@ -364,28 +377,30 @@
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3",
    "display_name": "Python 2",
    "language": "python",
    "name": "python3"
    "name": "python2"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.6.8"
    "pygments_lexer": "ipython2",
    "version": "2.7.14"
    }
    },
    "nbformat": 4,
  3. @mf1024 mf1024 revised this gist Dec 17, 2019. 1 changed file with 17 additions and 8 deletions.
    25 changes: 17 additions & 8 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -12,7 +12,7 @@
    "\n",
    "For this experiment, I will use a pre-trained GPT-2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "\n",
    "#### If you haven't yet, check out the notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details about setting up and using the pre-trained model for text generation."
    "#### If you haven't yet, check out the notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where use the same pretrained model to generate text."
    ]
    },
    {
    @@ -68,9 +68,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### PyTorch Dataset module for Reddit jokes\n",
    "### PyTorch Dataset module for Short jokes dataset\n",
    "\n",
    "For fine-tuning the GPT2 model, I will use Reddit jokes from [this](https://github.com/taivop/joke-dataset/blob/master/reddit_jokes.json) dataset. After each joke sample, I add \"<|endofext|>\" which is recognized by the GPT2 model as and end of text marker. The marker will allow me to concatenate many jokes in a single sequence input sequence."
    "For fine-tuning the GPT2 model, I will use this [Short Jokes dataset](https://www.kaggle.com/abhinavmoudgil95/short-jokes) published on Kaggle. After each joke, I add \"<|endofext|>\" which is recognized by the GPT2 model as and end of text marker. The marker will allow me to concatenate many jokes in a single input sequence."
    ]
    },
    {
    @@ -125,9 +125,9 @@
    "source": [
    "### Hyperparameters\n",
    "\n",
    "I tested many(I think 5) hyperparameter sets till I found one that works the best. I mostly tuned ***BATCH_SIZE*** (in this case, it's the number of forward-backward passes between each optimization step), ***EOPOCHS***, and ***LEARNING_RATE***.\n",
    "I tested many(more than 5) hyperparameter sets till I found one that works the best. I mostly tuned ***BATCH_SIZE*** (in this case, it's the number of forward-backward passes between each optimization step), ***EOPOCHS***, and ***LEARNING_RATE***.\n",
    "\n",
    "For a parameter value starting point for fine-tuning, I inspired from [this](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [this](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py) piece of code."
    "For a parameter value starting point for fine-tuning, I inspired from [this](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [this](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py) huggingface fine-tuning code."
    ]
    },
    {
    @@ -331,8 +331,8 @@
    " n = 20\n",
    " else:\n",
    " n = 3\n",
    " next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) choose the next word from the top n words\n",
    " cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word\n",
    " next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) select the next word\n",
    " cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word to the running sequence\n",
    "\n",
    " if next_token_id in tokenizer.encode('<|endoftext|>'):\n",
    " joke_finished = True\n",
    @@ -356,8 +356,17 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "The output was too long, so I stored it in [this file](https://github.com/mf1024/transformers/blob/master/generated_2_jokes.txt)."
    "3rd epoch model seemed to perform the best.\n",
    "\n",
    "The generated jokes output was too long for a notebook, so I stored it in [this file](https://github.com/mf1024/transformers/blob/master/generated_2_jokes.txt)."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
  4. @mf1024 mf1024 revised this gist Dec 17, 2019. 1 changed file with 130 additions and 172 deletions.
    302 changes: 130 additions & 172 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -40,7 +40,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
    @@ -51,12 +51,12 @@
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
    "def choose_from_top(probs, n=5):\n",
    " ind = np.argpartition(probs, -n)top[-n:]\n",
    " ind = np.argpartition(probs, -n)[-n:]\n",
    " top_prob = probs[ind]\n",
    " top_prob = top_prob / np.sum(top_prob) # Normalize\n",
    " choice = np.random.choice(n, 1, p = top_prob)\n",
    @@ -75,31 +75,33 @@
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import os\n",
    "import json\n",
    "import csv\n",
    "\n",
    "class JokesDataset(Dataset):\n",
    " def __init__(self, jokes_dataset_path = 'jokes_data/'):\n",
    " super().__init__()\n",
    "\n",
    " reddit_jokes_path = os.path.join(jokes_dataset_path, 'reddit_jokes.json')\n",
    "\n",
    " with open(reddit_jokes_path) as f:\n",
    " data = json.load(f)\n",
    " short_jokes_path = os.path.join(jokes_dataset_path, 'shortjokes.csv')\n",
    "\n",
    " self.joke_list = []\n",
    " self.end_of_text_token = \"<|endoftext|>\"\n",
    "\n",
    " for idx, joke_json in enumerate(data):\n",
    " joke_str = f\"{self.end_of_text_token}START:{joke_json['title']} {joke_json['body']}{self.end_of_text_token}\"\n",
    " self.joke_list.append(joke_str)\n",
    " \n",
    " with open(short_jokes_path) as csv_file:\n",
    " csv_reader = csv.reader(csv_file, delimiter=',')\n",
    " \n",
    " x = 0\n",
    " for row in csv_reader:\n",
    " joke_str = f\"JOKE:{row[1]}{self.end_of_text_token}\"\n",
    " self.joke_list.append(joke_str)\n",
    " \n",
    " def __len__(self):\n",
    " return len(self.joke_list)\n",
    "\n",
    @@ -109,7 +111,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": 9,
    "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
    @@ -130,14 +132,14 @@
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
    "BATCH_SIZE = 8\n",
    "EPOCHS = 3\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 5\n",
    "LEARNING_RATE = 3e-5\n",
    "WARMUP_STEPS = 10000\n",
    "WARMUP_STEPS = 5000\n",
    "MAX_SEQ_LEN = 400\n",
    "from transformers import AdamW, WarmupLinearSchedule\n",
    "\n",
    @@ -150,117 +152,73 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Model training"
    "### Model training\n",
    "\n",
    "I will train the model and save the model weights after each epoch and then I will try to generate jokes with each version of the weight to see which performs the best."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 12,
    "execution_count": 8,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "EPOCH 0 started==============================\n",
    "sum loss 3436.03076171875\n",
    "sum loss 3271.525390625\n",
    "sum loss 3117.43603515625\n",
    "sum loss 3020.50732421875\n",
    "sum loss 2940.70654296875\n",
    "sum loss 2915.05078125\n",
    "sum loss 2895.248779296875\n",
    "sum loss 2849.1494140625\n",
    "sum loss 2863.0771484375\n",
    "sum loss 2827.261474609375\n",
    "sum loss 2824.12109375\n",
    "sum loss 2795.527587890625\n",
    "sum loss 2803.104248046875\n",
    "sum loss 2802.185791015625\n",
    "sum loss 2786.28515625\n",
    "sum loss 2778.56982421875\n",
    "sum loss 2762.30615234375\n",
    "sum loss 2770.957763671875\n",
    "sum loss 2754.240478515625\n",
    "sum loss 2747.343994140625\n",
    "sum loss 2726.651611328125\n",
    "sum loss 2729.69189453125\n",
    "sum loss 2744.18212890625\n",
    "sum loss 2724.10400390625\n",
    "sum loss 2714.15283203125\n",
    "sum loss 2711.803955078125\n",
    "sum loss 2703.7119140625\n",
    "sum loss 2704.89208984375\n",
    "sum loss 2697.152099609375\n",
    "sum loss 2689.19189453125\n",
    "sum loss 2671.143798828125\n",
    "sum loss 2677.786376953125\n",
    "sum loss 6755.8515625\n",
    "sum loss 6162.54736328125\n",
    "sum loss 5615.0087890625\n",
    "sum loss 5427.8349609375\n",
    "sum loss 5346.51416015625\n",
    "sum loss 5282.28466796875\n",
    "sum loss 5242.89794921875\n",
    "sum loss 5210.7109375\n",
    "sum loss 5167.60400390625\n",
    "EPOCH 1 started==============================\n",
    "sum loss 2645.66064453125\n",
    "sum loss 2622.15966796875\n",
    "sum loss 2662.0419921875\n",
    "sum loss 2641.111083984375\n",
    "sum loss 2650.60546875\n",
    "sum loss 2637.7275390625\n",
    "sum loss 2634.77392578125\n",
    "sum loss 2637.45947265625\n",
    "sum loss 2636.165771484375\n",
    "sum loss 2640.04833984375\n",
    "sum loss 2617.153564453125\n",
    "sum loss 2619.0419921875\n",
    "sum loss 2604.137451171875\n",
    "sum loss 2599.365966796875\n",
    "sum loss 2615.25048828125\n",
    "sum loss 2592.6669921875\n",
    "sum loss 2599.585693359375\n",
    "sum loss 2598.38671875\n",
    "sum loss 2583.61376953125\n",
    "sum loss 2582.79296875\n",
    "sum loss 2598.868408203125\n",
    "sum loss 2579.271728515625\n",
    "sum loss 2582.907470703125\n",
    "sum loss 2582.181396484375\n",
    "sum loss 2595.12353515625\n",
    "sum loss 2563.3701171875\n",
    "sum loss 2551.559814453125\n",
    "sum loss 2555.54833984375\n",
    "sum loss 2581.083740234375\n",
    "sum loss 2553.90234375\n",
    "sum loss 2576.71240234375\n",
    "sum loss 2559.1435546875\n",
    "sum loss 5125.49658203125\n",
    "sum loss 5123.25830078125\n",
    "sum loss 5083.72998046875\n",
    "sum loss 5077.42138671875\n",
    "sum loss 5054.7705078125\n",
    "sum loss 5012.50830078125\n",
    "sum loss 5004.54150390625\n",
    "sum loss 4992.71435546875\n",
    "sum loss 4983.85205078125\n",
    "sum loss 4962.859375\n",
    "EPOCH 2 started==============================\n",
    "sum loss 2484.018798828125\n",
    "sum loss 2495.583984375\n",
    "sum loss 2504.78466796875\n",
    "sum loss 2515.67822265625\n",
    "sum loss 2489.676025390625\n",
    "sum loss 2484.712890625\n",
    "sum loss 2504.12109375\n",
    "sum loss 2473.819580078125\n",
    "sum loss 2479.34326171875\n",
    "sum loss 2486.447265625\n",
    "sum loss 2476.236328125\n",
    "sum loss 2456.914794921875\n",
    "sum loss 2465.84375\n",
    "sum loss 2499.381591796875\n",
    "sum loss 2472.593994140625\n",
    "sum loss 2437.755126953125\n",
    "sum loss 2472.271728515625\n",
    "sum loss 2451.429931640625\n",
    "sum loss 2459.23291015625\n",
    "sum loss 2483.620361328125\n",
    "sum loss 2445.205322265625\n",
    "sum loss 2467.73095703125\n",
    "sum loss 2455.337158203125\n",
    "sum loss 2473.5849609375\n",
    "sum loss 2487.701416015625\n",
    "sum loss 2458.47509765625\n",
    "sum loss 2455.560302734375\n",
    "sum loss 2471.487060546875\n",
    "sum loss 2481.9384765625\n",
    "sum loss 2469.49072265625\n",
    "sum loss 2475.31884765625\n",
    "sum loss 2467.7421875\n"
    "sum loss 4922.1162109375\n",
    "sum loss 4891.734375\n",
    "sum loss 4888.96923828125\n",
    "sum loss 4877.55224609375\n",
    "sum loss 4866.8857421875\n",
    "sum loss 4853.67236328125\n",
    "sum loss 4847.48681640625\n",
    "sum loss 4848.1748046875\n",
    "sum loss 4823.41162109375\n",
    "sum loss 4812.2216796875\n",
    "EPOCH 3 started==============================\n",
    "sum loss 4760.50537109375\n",
    "sum loss 4739.74072265625\n",
    "sum loss 4735.41943359375\n",
    "sum loss 4724.58837890625\n",
    "sum loss 4720.5361328125\n",
    "sum loss 4702.38671875\n",
    "sum loss 4703.87744140625\n",
    "sum loss 4696.990234375\n",
    "sum loss 4683.61181640625\n",
    "EPOCH 4 started==============================\n",
    "sum loss 4663.662109375\n",
    "sum loss 4584.58251953125\n",
    "sum loss 4596.220703125\n",
    "sum loss 4595.849609375\n",
    "sum loss 4583.38427734375\n",
    "sum loss 4582.66552734375\n",
    "sum loss 4574.6728515625\n",
    "sum loss 4563.45166015625\n",
    "sum loss 4559.3408203125\n",
    "sum loss 4564.5859375\n"
    ]
    }
    ],
    @@ -274,6 +232,9 @@
    "batch_count = 0\n",
    "\n",
    "tmp_jokes_tens = None\n",
    "models_folder = \"trained_models\"\n",
    "if not os.path.exists(models_folder):\n",
    " os.mkdir(models_folder)\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    " \n",
    @@ -316,89 +277,86 @@
    " scheduler.step() \n",
    " optimizer.zero_grad()\n",
    " model.zero_grad()\n",
    " \n",
    " if batch_count == 1000:\n",
    "\n",
    " if batch_count == 100:\n",
    " print(f\"sum loss {sum_loss}\")\n",
    " batch_count = 0\n",
    " sum_loss = 0.0"
    " sum_loss = 0.0\n",
    " \n",
    " # Store the model after each epoch to compare the performance of them\n",
    " torch.save(model.state_dict(), os.path.join(models_folder, f\"gpt2_medium_joker_{epoch}.pt\"))\n",
    " "
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Generating some jokes"
    "### Generating the jokes"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
    "MODEL_EPOCH = 4\n",
    "\n",
    "models_folder = \"trained_models\"\n",
    "\n",
    "model_path = os.path.join(models_folder, f\"gpt2_medium_joker_{MODEL_EPOCH}.pt\")\n",
    "model.load_state_dict(torch.load(model_path))\n",
    "\n",
    "jokes_output_file_path = f'generated_{MODEL_EPOCH}.jokes'\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "if os.path.exists(jokes_output_file_path):\n",
    " os.remove(jokes_output_file_path)\n",
    " \n",
    " for joke_idx in range(500):\n",
    "\n",
    " cur_ids = torch.tensor(tokenizer.encode(\"<|endoftext|>START:\")).unsqueeze(0).to(device)\n",
    "joke_num = 0\n",
    "with torch.no_grad():\n",
    " \n",
    " for joke_idx in range(1000):\n",
    " \n",
    " for i in range(250):\n",
    " outputs = model(cur_ids, labels=cur_ids)\n",
    " loss, logits = outputs[:2]\n",
    " softmax_logits = torch.softmax(logits[0,-1], dim=0) #Take the first(only one in this case) batch and the last predicted embedding\n",
    " if i < 2:\n",
    " n = 15\n",
    " else:\n",
    " n = 3\n",
    " next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) choose the next word from the top n words\n",
    " cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word\n",
    " joke_finished = False\n",
    "\n",
    " cur_ids = torch.tensor(tokenizer.encode(\"JOKE:\")).unsqueeze(0).to(device)\n",
    "\n",
    " for i in range(100):\n",
    " outputs = model(cur_ids, labels=cur_ids)\n",
    " loss, logits = outputs[:2]\n",
    " softmax_logits = torch.softmax(logits[0,-1], dim=0) #Take the first(from only one in this case) batch and the last predicted embedding\n",
    " if i < 3:\n",
    " n = 20\n",
    " else:\n",
    " n = 3\n",
    " next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) choose the next word from the top n words\n",
    " cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word\n",
    "\n",
    " if next_token_id in tokenizer.encode('<|endoftext|>'):\n",
    " joke_finished = True\n",
    " break\n",
    "\n",
    " if next_token_id in tokenizer.encode('<|endoftext|>'):\n",
    " break\n",
    " \n",
    " output_list = list(cur_ids.squeeze().to('cpu').numpy())\n",
    " output_text = tokenizer.decode(output_list)\n",
    " print(f\"JOKE NR {joke_idx}: {output_text} \\n\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "The output was too long, so I stored it in [this file](github.link)."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Store the model for later use"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "torch.save(model.state_dict(), os.path.join(\"gpt2_medium_joker.pt\"))"
    " if joke_finished:\n",
    " \n",
    " joke_num = joke_num + 1\n",
    " \n",
    " output_list = list(cur_ids.squeeze().to('cpu').numpy())\n",
    " output_text = tokenizer.decode(output_list)\n",
    "\n",
    " with open(jokes_output_file_path, 'a') as f:\n",
    " f.write(f\"{output_text} \\n\\n\")\n",
    " \n",
    " "
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Load stored model to generate more jokes"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "model.load_state_dict(torch.load(\"gpt2_medium_joker.pt\"))"
    "The output was too long, so I stored it in [this file](https://github.com/mf1024/transformers/blob/master/generated_2_jokes.txt)."
    ]
    }
    ],
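    The revision above saves a checkpoint after every epoch (`gpt2_medium_joker_{epoch}.pt`) and then hard-codes `MODEL_EPOCH = 4` for generation. To actually compare the per-epoch models, as the notebook text suggests, one can loop over the saved files and sample a few jokes from each. The sketch below reuses `EPOCHS`, `models_folder`, `tokenizer`, `model`, `device`, and `choose_from_top` as they appear in this diff; only the outer loop and `samples_per_checkpoint` are my own additions.

    # Checkpoint-comparison sketch: sample a few jokes from every saved epoch checkpoint.
    import os
    import torch

    samples_per_checkpoint = 3  # illustrative choice, not from the notebook

    for epoch in range(EPOCHS):
        model_path = os.path.join(models_folder, f"gpt2_medium_joker_{epoch}.pt")
        if not os.path.exists(model_path):
            continue
        model.load_state_dict(torch.load(model_path))
        model.eval()
        print(f"=== checkpoint from epoch {epoch} ===")
        with torch.no_grad():
            for _ in range(samples_per_checkpoint):
                cur_ids = torch.tensor(tokenizer.encode("JOKE:")).unsqueeze(0).to(device)
                for i in range(100):
                    loss, logits = model(cur_ids, labels=cur_ids)[:2]
                    probs = torch.softmax(logits[0, -1], dim=0)
                    n = 20 if i < 3 else 3  # sample broadly only for the first few tokens
                    next_id = choose_from_top(probs.to("cpu").numpy(), n=n)
                    cur_ids = torch.cat([cur_ids, torch.ones((1, 1)).long().to(device) * next_id], dim=1)
                    if next_id in tokenizer.encode("<|endoftext|>"):
                        break
                print(tokenizer.decode(list(cur_ids.squeeze().to("cpu").numpy())), "\n")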
  5. @mf1024 mf1024 revised this gist Dec 12, 2019. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -6,13 +6,13 @@
    "source": [
    "# Fine-tuning GPT-2 on a jokes dataset in PyTorch\n",
    "\n",
    "This notebook was create as a part of a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "This notebook was created as a part of a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "\n",
    "Let's see if the model can learn to crack some jokes!\n",
    "\n",
    "For this experiment, I will use a pre-trained GPT-2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "\n",
    "#### If you haven't yet, check out the text generation notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details about setting up and using the pre-trained model."
    "#### If you haven't yet, check out the notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details about setting up and using the pre-trained model for text generation."
    ]
    },
    {
  6. @mf1024 mf1024 revised this gist Dec 12, 2019. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -10,9 +10,9 @@
    "\n",
    "Let's see if the model can learn to crack some jokes!\n",
    "\n",
    "For this experiment, I will use a pre-trained GPT2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "For this experiment, I will use a pre-trained GPT-2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "\n",
    "#### If you haven't yet, check out the *GPT2LMHeadModel* text generation notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details on setting up and using the pre-trained model."
    "#### If you haven't yet, check out the text generation notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details about setting up and using the pre-trained model."
    ]
    },
    {
  7. @mf1024 mf1024 revised this gist Dec 12, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -6,7 +6,7 @@
    "source": [
    "# Fine-tuning GPT-2 on a jokes dataset in PyTorch\n",
    "\n",
    "This notebook is written for a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "This notebook was create as a part of a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "\n",
    "Let's see if the model can learn to crack some jokes!\n",
    "\n",
  8. @mf1024 mf1024 revised this gist Dec 12, 2019. No changes.
  9. @mf1024 mf1024 revised this gist Dec 12, 2019. 1 changed file with 12 additions and 36 deletions.
    48 changes: 12 additions & 36 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    Original file line number Diff line number Diff line change
    @@ -4,13 +4,15 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "# Fine-tuning GPT2-medium on a jokes dataset in PyTorch\n",
    "# Fine-tuning GPT-2 on a jokes dataset in PyTorch\n",
    "\n",
    "This is an experimental notebook for fine-tuning pre-trained GPT2-medium model on a jokes dataset. Let's see if it can learn to crack some jokes. \n",
    "This notebook is written for a blog post - [Fine-tuning large Transformer models on a single GPU in PyTorch - Teaching GPT-2 a sense of humor](www). Here I demonstrate how to fine-tune a pre-trained GPT-2 model on a jokes dataset. \n",
    "\n",
    "For this purpose, I will use the pre-trained GPT2 medium size model from huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "Let's see if the model can learn to crack some jokes!\n",
    "\n",
    "#### First, check out the *GPT2LMHeadModel* text generation experiments in this [gist](www.com). "
    "For this experiment, I will use a pre-trained GPT2 medium-sized model from the huggingface [transformers repository](https://github.com/huggingface/transformers).\n",
    "\n",
    "#### If you haven't yet, check out the *GPT2LMHeadModel* text generation notebook in this [gist](https://gist.github.com/mf1024/430d7fd6ff527350d3e4b5bda0d8614e) where you will find some more details on setting up and using the pre-trained model."
    ]
    },
    {
    @@ -49,12 +51,10 @@
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Function to select topN() tokens from the probability list(p) and then based on the P(p|top) distribution\n",
    "# select random element\n",
    "def choose_from_top(probs, n=5):\n",
    " ind = np.argpartition(probs, -n)top[-n:]\n",
    " top_prob = probs[ind]\n",
    @@ -70,7 +70,7 @@
    "source": [
    "### PyTorch Dataset module for Reddit jokes\n",
    "\n",
    "For fine-tuning the GPT2 model, I will use [this](https://github.com/taivop/joke-dataset/blob/master/reddit_jokes.json) jokes dataset. After each joke sample, I add \"<|endofext|>\" which is recognized by the GPT2 model as and end of text marker. The marker will allow me to concatenate many jokes in one sequence."
    "For fine-tuning the GPT2 model, I will use Reddit jokes from [this](https://github.com/taivop/joke-dataset/blob/master/reddit_jokes.json) dataset. After each joke sample, I add \"<|endofext|>\" which is recognized by the GPT2 model as and end of text marker. The marker will allow me to concatenate many jokes in a single sequence input sequence."
    ]
    },
    {
    @@ -117,32 +117,15 @@
    "joke_loader = DataLoader(dataset, batch_size=1, shuffle=True)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Fine-tuning GPT2-medium on a single GPU\n",
    " \n",
    "Large Transformer models are usually trained in multi-GPU(or TPU) settings because training on reasonable batch size and sequence length requires lots of [tensor|graphical] processing unit memory. My machine is equipped with a single GeForce 1080 Ti, which has 11 GB of memory. By empirical tests on the GPT2 medium model, I found that the maximum total sequence element count in all batches for my GPU to backprop trough is approximately 550, which is not a lot and might not be sufficient for successful fine-tuning. \n",
    "\n",
    "But there are some things we can take into account and improve the situation. \n",
    "\n",
    "The first thing to notice is that batch size in forward/ backward pass of the Transformer does not play a role because [Layer Normalization](https://arxiv.org/abs/1607.06450) is used instead of Batch Normalization. In [Layer Normalization](https://mlexplained.com/2018/11/30/an-overview-of-normalization-methods-in-deep-learning/), each feature is normalized across the feature dimension. \n",
    "\n",
    "Second, we can collect gradients over multiple forward-backward passes, and only then do the model weight update(optimization step). This way, we can store in the memory of the GPU a computational graph of one sequence at a time instead of storing a computational graph of all of the batch. With this strategy, we can get the same result as if the batch would have been processed in a single forward/backward pass, only with *BATCH_SIZE* times less memory.\n",
    "\n",
    "Putting it all together - I will process one sequence at a time with a maximum length of 400. The length of joke sequences varies a lot in the dataset I use, and to make the total sequence element count in one optimization step more consistent, I will try to fit in as many jokes as possible in the 400 element sequence. "
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Hyperparameters\n",
    "\n",
    "I tested many(I think 5) hyperparameter sets till I found one that works the best. I mostly changed ***BATCH_SIZE*** (in this case, it's the number of forward-backward passes between each optimization step), ***EOPOCHS***, and ***LEARNING_RATE***.\n",
    "I tested many(I think 5) hyperparameter sets till I found one that works the best. I mostly tuned ***BATCH_SIZE*** (in this case, it's the number of forward-backward passes between each optimization step), ***EOPOCHS***, and ***LEARNING_RATE***.\n",
    "\n",
    "For a parameter value starting point for fine-tuning, I inspired from [this](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [this](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py)."
    "For a parameter value starting point for fine-tuning, I inspired from [this](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [this](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py) piece of code."
    ]
    },
    {
    @@ -383,7 +366,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "The output was too long and I stored it in [this file](github.link)."
    "The output was too long, so I stored it in [this file](github.link)."
    ]
    },
    {
    @@ -406,7 +389,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "### Load stored model to generate some more jokes"
    "### Load stored model to generate more jokes"
    ]
    },
    {
    @@ -417,13 +400,6 @@
    "source": [
    "model.load_state_dict(torch.load(\"gpt2_medium_joker.pt\"))"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
  10. @mf1024 mf1024 revised this gist Nov 22, 2019. 1 changed file with 49 additions and 4022 deletions.
    4,071 changes: 49 additions & 4,022 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    49 additions, 4,022 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
  11. @mf1024 mf1024 created this gist Nov 21, 2019.
    4,423 changes: 4,423 additions & 0 deletions Fine-tuning GPT2-medium in PyTorch.ipynb
    4,423 additions, 0 deletions not shown because the diff is too large. Please use a local Git client to view these changes.