{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard" }, "cells": [ { "cell_type": "markdown", "source": [ "\n", "**Best-of-n sampling class usage**\n", "\n" ], "metadata": { "id": "WQpNapZNWuXP" } }, { "cell_type": "markdown", "source": [ "Import dependencies\n" ], "metadata": { "id": "Lo98lkdP66_x" } }, { "cell_type": "code", "source": [ "%pip install torch datasets transformers git+https://github.com/metric-space/trl.git@140/best-of-n-sampling-class" ], "metadata": { "id": "vDA6qayz692w" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "import pandas as pd\n", "from transformers import pipeline, AutoTokenizer\n", "from datasets import load_dataset\n", "\n", "from trl import AutoModelForCausalLMWithValueHead\n", "from trl.core import LengthSampler\n", "from trl.extras import BestOfNSampler\n", "\n", "device = 0 if torch.cuda.is_available() else \"cpu\" " ], "metadata": { "id": "M1s_iNm773hM" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "Various constants" ], "metadata": { "id": "Y7hyrIrO8tcY" } }, { "cell_type": "code", "source": [ "ref_model_name = 'lvwerra/gpt2-imdb'\n", "model_name = 'lvwerra/gpt2-imdb-pos-v2'\n", "reward_model = 'lvwerra/distilbert-imdb'\n", " \n", "N_BEST_OF = 4" ], "metadata": { "id": "MqS3OM6Q8x6g" }, "execution_count": 3, "outputs": [] }, { "cell_type": "markdown", "source": [ "Models and tokenizers " ], "metadata": { "id": "c1YcXeElg6or" } }, { "cell_type": "code", "source": [ "\n", "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", "\n", "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "# cuda-ize models\n", "ref_model.cuda()" ], "metadata": { "id": "b855NrL181Hh" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Dataset building" ], "metadata": { "id": "Z1Cz0gCFhZYJ" } }, { "cell_type": "code", "source": [ "def build_dataset(tokenizer, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n", " # load imdb with datasets\n", " ds = load_dataset(dataset_name, split=\"train\")\n", " ds = ds.rename_columns({\"text\": \"review\"})\n", " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", "\n", " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", "\n", " def tokenize(sample):\n", " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", " return sample\n", "\n", " ds = ds.map(tokenize, batched=False)\n", " ds.set_format(type=\"torch\")\n", " return ds\n", "\n", "dataset = build_dataset(tokenizer)" ], "metadata": { "id": "LqLVEp5p_8XM" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n", "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" ], "metadata": { "id": "AqA2McjMAxNw" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "output_min_length = 4\n", "output_max_length = 16\n", "output_length_sampler = 
{ "cell_type": "code", "source": [ "output_min_length = 4\n", "output_max_length = 16\n", "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", "\n", "#### get a batch from the dataset\n", "bs = 16\n", "output_data = dict()\n", "dataset.set_format(\"pandas\")\n", "df_batch = dataset[:].sample(bs)\n", "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", "query_tensors = df_batch[\"input_ids\"].tolist()\n", "\n", "# :: [Resp] -- one response per query\n", "response_tensors_ref, response_tensors = [], []\n", "# :: [[Resp]] -- several candidate responses per query\n", "response_tensors_best_of = []" ], "metadata": { "id": "L_q4qs35AxcR" }, "execution_count": 7, "outputs": [] },
{ "cell_type": "code", "source": [ "best_of_n = BestOfNSampler(ref_model, tokenizer, reward_pipe, reward_kwargs=sent_kwargs, length_sampler=output_length_sampler)\n", "best_of_n.generate(query_tensors, device=device, **gen_kwargs)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wDv5wz5DiTw4", "outputId": "d95e4bcc-fccd-4102-8934-768628ab9975" }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py:1080: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", " warnings.warn(\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "['I rented this film purely on the premise of looking at an aspiring actress who is very',\n", " 'An independent feature can now be seen via premium cinema where the ever as good Alice and Richard',\n", " 'When I saw this movie, I really enjoyed it',\n", " 'This movie has an all-time high for movie retellings, comedy that seems',\n", " 'I first saw this film about 7 years ago. I was amazed about Finney',\n", " 'A previous reviewer said the film has an ugly ending, but through the tender moments',\n", " 'To make any film a buddy to. When the script is too hack',\n", " 'A recent post here by Michael Tuul-Jung suggests that Newton did this',\n", " 'Though the award-winning doc',\n", " 'Steven Seagal, was in Zombie after World War II? And',\n", " '\"Plants recall something very different this time, living in the',\n", " \"This is the only movie I've seen that show him that deserves better than 1/10- with the ending\",\n", " 'I saw Chan Is Missing when it came out several years ago.<',\n", " 'Not that I want to be mean anymore. I felt like',\n", " \"I'm sure that rented out all but a\",\n", " 'If you want to know some more stories you can']" ] }, "metadata": {}, "execution_count": 8 } ] },
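{ "cell_type": "markdown", "source": [ "As a follow-up, the best-of-n texts can be scored explicitly with the reward pipeline and lined up against their queries. The cell below is a minimal sketch: it re-runs `generate` so the result can be stored in a variable, and it assumes the sampler returns one decoded string per query (as in the output above) and that the reward model uses a `POSITIVE` label. The names `best_texts` and `scores` are illustrative.\n" ], "metadata": {} },
{ "cell_type": "code", "source": [ "# Sketch: re-score the selected best-of-n continuations with the reward pipeline\n", "best_texts = best_of_n.generate(query_tensors, device=device, **gen_kwargs)\n", "scores = [\n", "    next(d[\"score\"] for d in output if d[\"label\"] == \"POSITIVE\")\n", "    for output in reward_pipe(best_texts, **sent_kwargs)\n", "]\n", "pd.DataFrame({\"query\": output_data[\"query\"], \"best_of_n_response\": best_texts, \"reward\": scores})" ], "metadata": {}, "execution_count": null, "outputs": [] } ] }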