@@ -0,0 +1,456 @@
{
"nbformat" : 4 ,
"nbformat_minor" : 0 ,
"metadata" : {
"colab" : {
"provenance" : [],
"gpuType" : " T4" ,
"authorship_tag" : " ABX9TyPefpSRKdN8EShzWByzf5mC"
},
"kernelspec" : {
"name" : " python3" ,
"display_name" : " Python 3"
},
"language_info" : {
"name" : " python"
},
"accelerator" : " GPU"
},
"cells" : [
{
"cell_type" : " markdown" ,
"source" : [
" # Full **GRPO** fine-tuning `Qwen2.5 0.5B` on a single T4\n " ,
" This colab uses a lot of tweaks and tricks to make GRPO **full fine-tuning** Qwen2.5-0.5-Instruct fit on a single T4 GPU, so that it could be run in a free Google Colab.\n " ,
" \n " ,
" It uses VLLm for fast inference and does not make compromises on batch and completion group sizes.\n " ,
" \n " ,
" With this setup you can improve `Qwen2.5-0.5B-Instruct`'s gsm8k eval result from 22.4% to 48.6% in just \\ ~150 steps (~30 minutes) on a single T4 GPU.\n " ,
" \n " ,
" \n " ,
" </br>\n " ,
" \n " ,
" ---\n " ,
" \n " ,
" Here are some important optimizations used:\n " ,
" \n " ,
" * A [fork](https://github.com/andyl98/trl/tree/grpo-vram-optimization) of the TRL repo by [andyl98](https://github.com/andyl98), which introduces batched logprobs calculation. I then forked this fork and further optimized the logprobs computation function to reduce VRAM usage.\n " ,
" * 8-bit AdamW optimizer\n " ,
" * Set explicit memory allocation limits with `PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128'`\n " ,
" \n " ,
" </br>\n " ,
" \n " ,
" ---\n " ,
" \n " ,
" If using Ampere, or later architecture nvidia GPU, you can further reduce VRAM usage by:\n " ,
" \n " ,
" \n " ,
" * enabling `attn_implementation=\" flash_attention_2\" ` during model loading\n " ,
" * loading the model with [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) wrapper:\n " ,
" \n " ,
" ```Python\n " ,
" from liger_kernel.transformers import AutoLigerKernelForCausalLM\n " ,
" model = AutoLigerKernelForCausalLM.from_pretrained(\" path/to/some/model\" )\n " ,
" ```"
],
"metadata" : {
"id" : " 8oW2D1_PpNqF"
}
},
{
"cell_type" : " code" ,
"source" : [
" %%capture\n " ,
" !pip install uv\n " ,
" !uv pip install --system git+https://github.com/qunash/trl-1.git@grpo-vram-optimization\n " ,
" !uv pip install --system triton==2.2.0\n " ,
" !uv pip install --system vllm\n " ,
" !uv pip install --system bitsandbytes"
],
"metadata" : {
"id" : " znbQSsMqi7HJ"
},
"execution_count" : 1 ,
"outputs" : []
},
{
"cell_type" : " code" ,
"source" : [
" import os\n " ,
" import re\n " ,
" import torch\n " ,
" from datasets import load_dataset, Dataset\n " ,
" from transformers import AutoTokenizer, AutoModelForCausalLM\n " ,
" from trl.trainer import GRPOConfig, GRPOTrainer\n " ,
" \n " ,
" \n " ,
" R1_STYLE_SYSTEM_PROMPT = \"\"\" A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n " ,
" The assistant first thinks about the reasoning process in the mind and then provides the user\n " ,
" with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and\n " ,
" <answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n " ,
" <answer> answer here </answer>.\"\"\"\n " ,
" \n " ,
" TASK_SPECIFIC_INSTRUCTIONS = \" The answer must be a single integer.\"\n " ,
" \n " ,
" \n " ,
" def preprocess_dataset(dataset_name, split=\" train\" , chunk_size=1000) -> Dataset:\n " ,
" dataset = load_dataset(dataset_name, 'main')[split]\n " ,
" \n " ,
" def extract_hash_answer(text: str) -> str | None:\n " ,
" try:\n " ,
" return text.split(\" ####\" )[1].strip()\n " ,
" except IndexError:\n " ,
" return None\n " ,
" \n " ,
" def process_batch(batch):\n " ,
" prompts = [[\n " ,
" {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + \"\\ n\" + TASK_SPECIFIC_INSTRUCTIONS},\n " ,
" {'role': 'user', 'content': \" What is 2+2?\" },\n " ,
" {'role': 'assistant', 'content': \" <reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\\ n<answer>4</answer>\" },\n " ,
" {'role': 'user', 'content': q.strip()}\n " ,
" ] for q in batch['question']]\n " ,
" \n " ,
" return {\n " ,
" 'prompt': prompts,\n " ,
" 'answer': [extract_hash_answer(a) for a in batch['answer']]\n " ,
" }\n " ,
" \n " ,
" return dataset.map(process_batch, batched=True, batch_size=chunk_size)\n " ,
" \n " ,
" dataset_name = 'openai/gsm8k'\n " ,
" dataset = preprocess_dataset(dataset_name, chunk_size=500)\n " ,
" \n " ,
" \n " ,
" def extract_xml_answer(text: str) -> str:\n " ,
" try:\n " ,
" answer = text.split(\" <answer>\" )[-1].split(\" </answer>\" )[0].strip()\n " ,
" return answer\n " ,
" except IndexError:\n " ,
" return \"\"\n " ,
" \n " ,
" # reward functions\n " ,
" # VALID_FORMAT = re.compile(r\" <reasoning>(?:(?!</?reasoning>|</?answer>).)*</reasoning>\\ n<answer>(?:(?!</?reasoning>|</?answer>).)*</answer>\" )\n " ,
" \n " ,
" # def format_reward_func(completions, **kwargs) -> list[float]:\n " ,
" # \"\"\" Reward function that checks if the completion has the correct format.\"\"\"\n " ,
" # responses = [completion[0][\" content\" ] for completion in completions]\n " ,
" # matches = [bool(VALID_FORMAT.fullmatch(r.strip())) for r in responses]\n " ,
" # return [1.0 if match else 0.0 for match in matches]\n " ,
" \n " ,
" def format_reward_func(completions, **kwargs) -> list[float]:\n " ,
" \"\"\" Reward function that checks if the completion has the correct format.\"\"\"\n " ,
" pattern = r\" ^<reasoning>.*?</reasoning>\\ s*<answer>.*?</answer>$\"\n " ,
" responses = [completion[0][\" content\" ] for completion in completions]\n " ,
" matches = [bool(re.match(pattern, r)) for r in responses]\n " ,
" return [1.0 if match else 0.0 for match in matches]\n " ,
" \n " ,
" def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n " ,
" \"\"\" Reward function that checks if the answer is correct.\"\"\"\n " ,
" responses = [completion[0]['content'] for completion in completions]\n " ,
" extracted_responses = [extract_xml_answer(r) for r in responses]\n " ,
" print(f\" Question: {prompts[0][-1]['content']}\\ nAnswer: {answer[0]}\\ nResponse: {responses[0]}\\ nExtracted: {extracted_responses[0]}\" )\n " ,
" print(''.join('✅' if r == a else '❌' for r, a in zip(extracted_responses, answer)))\n " ,
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n " ,
" \n " ,
" # model_name = \" Qwen/Qwen2.5-0.5B\"\n " ,
" model_name = \" Qwen/Qwen2.5-0.5B-Instruct\"\n " ,
" \n " ,
" output_dir = f\" outputs/{model_name.split('/')[-1]}-GRPO\"\n " ,
" run_name = f\" {model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}\"\n " ,
" \n " ,
" \n " ,
" # Set memory-related environment variables\n " ,
" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'\n " ,
" \n " ,
" max_prompt_length=256\n " ,
" max_completion_length=512\n " ,
" \n " ,
" training_args = GRPOConfig(\n " ,
" output_dir=output_dir,\n " ,
" run_name=run_name,\n " ,
" learning_rate=1e-5,\n " ,
" beta=0.005, # divergence coefficient – how much the policy is allowed to deviate from the reference model. higher value – more conservative updates. Default is 0.04\n " ,
" optim=\" adamw_8bit\" ,\n " ,
" adam_beta1=0.9,\n " ,
" adam_beta2=0.99,\n " ,
" weight_decay=0.1,\n " ,
" warmup_ratio=0.1,\n " ,
" lr_scheduler_type='cosine',\n " ,
" logging_steps=1,\n " ,
" bf16=True,\n " ,
" per_device_train_batch_size=4,\n " ,
" num_generations=4, # group size\n " ,
" gradient_accumulation_steps=4,\n " ,
" max_prompt_length=max_prompt_length,\n " ,
" max_completion_length=max_completion_length,\n " ,
" num_train_epochs=1,\n " ,
" save_steps=100,\n " ,
" max_grad_norm=0.1,\n " ,
" report_to=\" wandb\" ,\n " ,
" log_on_each_node=False,\n " ,
" use_vllm=True,\n " ,
" vllm_init_kwargs={\n " ,
" \" device\" : \" cuda:0\" ,\n " ,
" \" gpu_memory_utilization\" : 0.3,\n " ,
" \" max_model_len\" : max_prompt_length + max_completion_length,\n " ,
" \" dtype\" : \" half\" ,\n " ,
" # \" enable_chunked_prefill\" : True,\n " ,
" # \" max_num_batched_tokens\" : 2048,\n " ,
" },\n " ,
" gradient_checkpointing=True,\n " ,
" gradient_checkpointing_kwargs={\" use_reentrant\" : False},\n " ,
" logit_computation_mini_batch_size=1,\n " ,
" enable_profiling=False\n " ,
" )\n " ,
" \n " ,
" # Load model\n " ,
" model = AutoModelForCausalLM.from_pretrained(\n " ,
" model_name,\n " ,
" torch_dtype=torch.bfloat16,\n " ,
" # attn_implementation=\" flash_attention_2\" , # T4 is not supported\n " ,
" device_map=\" auto\" ,\n " ,
" )\n " ,
" \n " ,
" tokenizer = AutoTokenizer.from_pretrained(\n " ,
" model_name,\n " ,
" model_max_length=training_args.max_completion_length,\n " ,
" )\n " ,
" tokenizer.pad_token = tokenizer.eos_token\n " ,
" \n " ,
" # Initialize trainer\n " ,
" trainer = GRPOTrainer(\n " ,
" model=model,\n " ,
" processing_class=tokenizer,\n " ,
" reward_funcs=[\n " ,
" format_reward_func,\n " ,
" correctness_reward_func\n " ,
" ],\n " ,
" args=training_args,\n " ,
" train_dataset=dataset,\n " ,
" )\n " ,
" \n " ,
" trainer.train()\n "
],
"metadata" : {
"id" : " RnACYvTBWA1q"
},
"execution_count" : null ,
"outputs" : []
},
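{
"cell_type": "markdown",
"source": [
"# Save the trained model\n",
"\n",
"A minimal sketch for persisting the fine-tuned weights (the directory name `final` is an arbitrary choice). Point `checkpoint_path` in the eval cell below at this folder, or at one of the `checkpoint-*` folders written every `save_steps`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Save the final policy and tokenizer so the eval cell can load them with vLLM.\n",
"final_dir = f\"{output_dir}/final\"\n",
"trainer.save_model(final_dir)\n",
"tokenizer.save_pretrained(final_dir)\n",
"print(f\"Saved to {final_dir}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},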
{
"cell_type" : " markdown" ,
"source" : [
" # Eval"
],
"metadata" : {
"id" : " OFRHLvWhzHn0"
}
},
{
"cell_type" : " code" ,
"source" : [
" import torch\n " ,
" from datasets import load_dataset\n " ,
" from transformers import AutoTokenizer\n " ,
" from vllm import LLM, SamplingParams\n " ,
" from tqdm.notebook import tqdm\n " ,
" import numpy as np\n " ,
" from typing import List, Dict\n " ,
" import json\n " ,
" from datetime import datetime\n " ,
" import logging\n " ,
" \n " ,
" # Disable VLLM's progress bars\n " ,
" logging.getLogger(\" vllm\" ).setLevel(logging.WARNING)\n " ,
" \n " ,
" # Constants from training script\n " ,
" R1_STYLE_SYSTEM_PROMPT = \"\"\" A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n " ,
" The assistant first thinks about the reasoning process in the mind and then provides the user\n " ,
" with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and\n " ,
" <answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n " ,
" <answer> answer here </answer>.\"\"\"\n " ,
" \n " ,
" TASK_SPECIFIC_INSTRUCTIONS = \" The answer must be a single integer.\"\n " ,
" \n " ,
" def extract_xml_answer(text: str) -> str:\n " ,
" try:\n " ,
" answer = text.split(\" <answer>\" )[-1].split(\" </answer>\" )[0].strip()\n " ,
" return answer\n " ,
" except IndexError:\n " ,
" return \"\"\n " ,
" \n " ,
" def extract_hash_answer(text: str) -> str | None:\n " ,
" try:\n " ,
" return text.split(\" ####\" )[1].strip()\n " ,
" except IndexError:\n " ,
" return None\n " ,
" \n " ,
" def evaluate_model(\n " ,
" model_path: str,\n " ,
" batch_size: int = 4,\n " ,
" num_samples: int = None,\n " ,
" save_results: bool = True,\n " ,
" gpu_memory_utilization: float = 0.3,\n " ,
" ) -> Dict:\n " ,
" print(\" Initializing evaluation...\" )\n " ,
" \n " ,
" # Initialize VLLM with progress indicator\n " ,
" with tqdm(total=2, desc=\" Loading model components\" ) as pbar:\n " ,
" llm = LLM(\n " ,
" model=model_path,\n " ,
" dtype=\" half\" ,\n " ,
" gpu_memory_utilization=gpu_memory_utilization,\n " ,
" max_model_len=768,\n " ,
" device=\" cuda:0\" ,\n " ,
" enable_chunked_prefill=True,\n " ,
" )\n " ,
" pbar.update(1)\n " ,
" \n " ,
" tokenizer = AutoTokenizer.from_pretrained(\n " ,
" model_path,\n " ,
" model_max_length=768,\n " ,
" padding_side='right',\n " ,
" truncation_side='right'\n " ,
" )\n " ,
" pbar.update(1)\n " ,
" \n " ,
" # Set up sampling parameters\n " ,
" sampling_params = SamplingParams(\n " ,
" temperature=0.0,\n " ,
" max_tokens=512, # Matching max_completion_length from training\n " ,
" stop_token_ids=[tokenizer.eos_token_id],\n " ,
" )\n " ,
" \n " ,
" # Load test dataset\n " ,
" print(\" Loading dataset...\" )\n " ,
" dataset = load_dataset('openai/gsm8k', 'main', split='test')\n " ,
" if num_samples:\n " ,
" dataset = dataset.select(range(num_samples))\n " ,
" total_samples = len(dataset)\n " ,
" print(f\" Loaded {total_samples} samples\" )\n " ,
" \n " ,
" results = []\n " ,
" correct = 0\n " ,
" total = 0\n " ,
" \n " ,
" # Create progress bar\n " ,
" progress_bar = tqdm(\n " ,
" total=total_samples,\n " ,
" desc=\" Processing samples\" ,\n " ,
" unit=\" examples\" ,\n " ,
" dynamic_ncols=True,\n " ,
" )\n " ,
" \n " ,
" progress_bar.set_postfix({\n " ,
" 'acc': '0.00%',\n " ,
" 'correct': '0',\n " ,
" })\n " ,
" \n " ,
" # Process in batches\n " ,
" for i in range(0, total_samples, batch_size):\n " ,
" batch_data = dataset[i:i + batch_size]\n " ,
" current_batch_size = len(batch_data['question'])\n " ,
" \n " ,
" # Prepare prompts using same format as training\n " ,
" prompts = [\n " ,
" [\n " ,
" {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + \"\\ n\" + TASK_SPECIFIC_INSTRUCTIONS},\n " ,
" {'role': 'user', 'content': \" What is 2+2?\" },\n " ,
" {'role': 'assistant', 'content': \" <reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\\ n<answer>4</answer>\" },\n " ,
" {'role': 'user', 'content': q.strip()}\n " ,
" ] for q in batch_data['question']\n " ,
" ]\n " ,
" \n " ,
" # Convert to chat format\n " ,
" formatted_prompts = [\n " ,
" tokenizer.apply_chat_template(\n " ,
" p,\n " ,
" tokenize=False,\n " ,
" add_generation_prompt=True\n " ,
" )\n " ,
" for p in prompts\n " ,
" ]\n " ,
" \n " ,
" # Generate responses\n " ,
" outputs = llm.generate(\n " ,
" formatted_prompts,\n " ,
" sampling_params,\n " ,
" )\n " ,
" \n " ,
" # Process responses\n " ,
" for j, output in enumerate(outputs):\n " ,
" response = output.outputs[0].text\n " ,
" \n " ,
" # Extract answers\n " ,
" generated_answer = extract_xml_answer(response)\n " ,
" true_answer = extract_hash_answer(batch_data['answer'][j])\n " ,
" \n " ,
" # Store result\n " ,
" result = {\n " ,
" 'question': batch_data['question'][j],\n " ,
" 'true_answer': true_answer,\n " ,
" 'generated_answer': generated_answer,\n " ,
" 'full_response': response,\n " ,
" 'correct': generated_answer == true_answer\n " ,
" }\n " ,
" results.append(result)\n " ,
" \n " ,
" # Update metrics\n " ,
" if generated_answer == true_answer:\n " ,
" correct += 1\n " ,
" total += 1\n " ,
" \n " ,
" # Update progress\n " ,
" progress_bar.update(current_batch_size)\n " ,
" progress_bar.set_postfix({\n " ,
" 'acc': f'{(correct/total)*100:.2f}%',\n " ,
" 'correct': f'{correct}/{total}',\n " ,
" })\n " ,
" \n " ,
" progress_bar.close()\n " ,
" \n " ,
" # Calculate metrics\n " ,
" accuracy = correct / total if total > 0 else 0\n " ,
" metrics = {\n " ,
" 'accuracy': accuracy,\n " ,
" 'correct': correct,\n " ,
" 'total': total,\n " ,
" 'model_path': model_path,\n " ,
" 'timestamp': datetime.now().isoformat()\n " ,
" }\n " ,
" \n " ,
" # Save results\n " ,
" if save_results:\n " ,
" save_path = f\" gsm8k_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n " ,
" with open(save_path, 'w') as f:\n " ,
" json.dump({\n " ,
" 'metrics': metrics,\n " ,
" 'results': results\n " ,
" }, f, indent=2)\n " ,
" print(f\"\\ nResults saved to {save_path}\" )\n " ,
" \n " ,
" return metrics\n " ,
" \n " ,
" print(\" Starting GSM8K evaluation...\" )\n " ,
" checkpoint_path = \" outputs/Qwen2.5-0.5B-Instruct-GRPO/checkpoint-latest\" # Update path as needed\n " ,
" \n " ,
" metrics = evaluate_model(\n " ,
" model_path=checkpoint_path,\n " ,
" batch_size=4,\n " ,
" num_samples=None,\n " ,
" save_results=True,\n " ,
" gpu_memory_utilization=0.3,\n " ,
" )\n " ,
" \n " ,
" print(\"\\ nFinal Evaluation Results:\" )\n " ,
" print(f\" Accuracy: {metrics['accuracy']:.2%}\" )\n " ,
" print(f\" Correct: {metrics['correct']}/{metrics['total']}\" )"
],
"metadata" : {
"id" : " nW6pJMSDD2sv"
},
"execution_count" : null ,
"outputs" : []
}
]
}