Skip to content

Instantly share code, notes, and snippets.

@metric-space
Created May 19, 2023 22:22
Show Gist options
  • Select an option

  • Save metric-space/a315a97d581039cd47888c24caba9f65 to your computer and use it in GitHub Desktop.

Select an option

Save metric-space/a315a97d581039cd47888c24caba9f65 to your computer and use it in GitHub Desktop.

Revisions

  1. metric-space created this gist May 19, 2023.
    252 changes: 252 additions & 0 deletions example_usage.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,252 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "colab": {
    "provenance": []
    },
    "kernelspec": {
    "name": "python3",
    "display_name": "Python 3"
    },
    "language_info": {
    "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
    },
    "cells": [
    {
    "cell_type": "markdown",
    "source": [
    "\n",
    "**Best-of-n sampling class usage**\n",
    "\n"
    ],
    "metadata": {
    "id": "WQpNapZNWuXP"
    }
    },
    {
    "cell_type": "markdown",
    "source": [
    "Install and import dependencies\n"
    ],
    "metadata": {
    "id": "Lo98lkdP66_x"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "%pip install torch datasets transformers git+https://github.com/metric-space/trl.git@140/best-of-n-sampling-class"
    ],
    "metadata": {
    "id": "vDA6qayz692w"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "import torch\n",
    "import pandas as pd\n",
    "from transformers import pipeline, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "\n",
    "from trl import AutoModelForCausalLMWithValueHead\n",
    "from trl.core import LengthSampler\n",
    "from trl.extras import BestOfNSampler\n",
    "\n",
    "device = 0 if torch.cuda.is_available() else \"cpu\" "
    ],
    "metadata": {
    "id": "M1s_iNm773hM"
    },
    "execution_count": 2,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "Various constants"
    ],
    "metadata": {
    "id": "Y7hyrIrO8tcY"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "ref_model_name = 'lvwerra/gpt2-imdb'\n",
    "model_name = 'lvwerra/gpt2-imdb-pos-v2'\n",
    "reward_model = 'lvwerra/distilbert-imdb'\n",
    " \n",
    "N_BEST_OF = 4"
    ],
    "metadata": {
    "id": "MqS3OM6Q8x6g"
    },
    "execution_count": 3,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "Models and tokenizers "
    ],
    "metadata": {
    "id": "c1YcXeElg6or"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "\n",
    "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
    "\n",
    "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# cuda-ize models\n",
    "ref_model.cuda()"
    ],
    "metadata": {
    "id": "b855NrL181Hh"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "Dataset building"
    ],
    "metadata": {
    "id": "Z1Cz0gCFhZYJ"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "def build_dataset(tokenizer, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n",
    " # load imdb with datasets\n",
    " ds = load_dataset(dataset_name, split=\"train\")\n",
    " ds = ds.rename_columns({\"text\": \"review\"})\n",
    " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
    "\n",
    " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
    "\n",
    " def tokenize(sample):\n",
    " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
    " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
    " return sample\n",
    "\n",
    " ds = ds.map(tokenize, batched=False)\n",
    " ds.set_format(type=\"torch\")\n",
    " return ds\n",
    "\n",
    "dataset = build_dataset(tokenizer)"
    ],
    "metadata": {
    "id": "LqLVEp5p_8XM"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n",
    "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
    ],
    "metadata": {
    "id": "AqA2McjMAxNw"
    },
    "execution_count": 6,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "\n",
    "output_min_length = 4\n",
    "output_max_length = 16\n",
    "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
    "\n",
    "#### get a batch from the dataset\n",
    "bs = 16\n",
    "output_data = dict()\n",
    "dataset.set_format(\"pandas\")\n",
    "df_batch = dataset[:].sample(bs)\n",
    "output_data[\"query\"] = df_batch[\"query\"].tolist()\n",
    "query_tensors = df_batch[\"input_ids\"].tolist()\n",
    "\n",
    "# :: [Resp]\n",
    "response_tensors_ref, response_tensors = [], []\n",
    "# :: [[Resp]]\n",
    "response_tensors_best_of = []\n",
    "\n"
    ],
    "metadata": {
    "id": "L_q4qs35AxcR"
    },
    "execution_count": 7,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "a = BestOfNSampler(ref_model, tokenizer, reward_pipe, reward_kwargs=sent_kwargs, length_sampler=output_length_sampler)\n",
    "a.generate(query_tensors, device=device, **gen_kwargs)"
    ],
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "wDv5wz5DiTw4",
    "outputId": "d95e4bcc-fccd-4102-8934-768628ab9975"
    },
    "execution_count": 8,
    "outputs": [
    {
    "output_type": "stream",
    "name": "stderr",
    "text": [
    "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py:1080: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
    " warnings.warn(\n"
    ]
    },
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "['I rented this film purely on the premise of looking at an aspiring actress who is very',\n",
    " 'An independent feature can now be seen via premium cinema where the ever as good Alice and Richard',\n",
    " 'When I saw this movie, I really enjoyed it',\n",
    " 'This movie has an all-time high for movie retellings, comedy that seems',\n",
    " 'I first saw this film about 7 years ago. I was amazed about Finney',\n",
    " 'A previous reviewer said the film has an ugly ending, but through the tender moments',\n",
    " 'To make any film a buddy to. When the script is too hack',\n",
    " 'A recent post here by Michael Tuul-Jung suggests that Newton did this',\n",
    " 'Though the award-winning doc',\n",
    " 'Steven Seagal, was in Zombie after World War II? And',\n",
    " '\"Plants recall something very different this time, living in the',\n",
    " \"This is the only movie I've seen that show him that deserves better than 1/10- with the ending\",\n",
    " 'I saw Chan Is Missing when it came out several years ago.<',\n",
    " 'Not that I want to be mean anymore. I felt like',\n",
    " \"I'm sure that rented out all but a\",\n",
    " 'If you want to know some more stories you can']"
    ]
    },
    "metadata": {},
    "execution_count": 8
    }
    ]
    }
    ]
    }