<a href="https://colab.research.google.com/gist/qunash/820c86d1d267ec8051d9f68b4f4bb656/grpo_qwen-0-5b_single_t4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Full **GRPO** fine-tuning of `Qwen2.5-0.5B` on a single T4
This Colab notebook uses several memory optimizations to make GRPO **full fine-tuning** of `Qwen2.5-0.5B-Instruct` fit on a single T4 GPU, so that it can run in a free Google Colab session.

It uses vLLM for fast inference and does not compromise on batch size or completion group size.
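
Group size matters because GRPO scores each completion relative to the other completions sampled for the same prompt. Below is a minimal sketch of that normalization, assuming the common formulation of subtracting the group mean and dividing by the group standard deviation. The function name and the `eps` constant are illustrative and are not taken from the TRL fork:

```python
from statistics import mean, stdev

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize the rewards of one prompt's completion group."""
    mu = mean(rewards)
    # Sample standard deviation; a single-completion group has no spread.
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]

# One prompt, four completions (num_generations=4). Each reward is the sum of
# the format reward (0 or 1) and the correctness reward (0 or 2).
mixed_group = group_relative_advantages([3.0, 1.0, 0.0, 0.0])
uniform_group = group_relative_advantages([1.0, 1.0, 1.0, 1.0])
print(mixed_group)
print(uniform_group)  # all zeros: identical rewards carry no learning signal
```

A group in which every completion earns the same reward yields zero advantages, which is one reason shrinking the group to save memory can hurt training.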

With this setup you can improve `Qwen2.5-0.5B-Instruct`'s GSM8K eval result from 22.4% to 48.6% in about 150 steps (about 30 minutes) on a single T4 GPU.


<br>

---

Here are some important optimizations used:

* A [fork](https://github.com/andyl98/trl/tree/grpo-vram-optimization) of the TRL repo by [andyl98](https://github.com/andyl98), which introduces batched logprobs calculation. I forked it in turn and further optimized the logprobs computation function to reduce VRAM usage.
* 8-bit AdamW optimizer
* Reduced allocator fragmentation with `PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128'`, which stops the CUDA caching allocator from splitting blocks larger than 128 MB
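
To see why the 8-bit optimizer matters for full fine-tuning, here is a back-of-the-envelope estimate of the AdamW state size. It assumes roughly 0.49 billion trainable parameters for Qwen2.5-0.5B and two state values per parameter, and it ignores the small overhead of quantization constants, so treat the numbers as approximate:

```python
GIB = 1024 ** 3
NUM_PARAMS = 494_000_000  # assumption: approximate parameter count of Qwen2.5-0.5B

def adamw_state_bytes(num_params: int, bytes_per_value: int) -> int:
    """AdamW keeps two values per parameter: first and second moment estimates."""
    return 2 * num_params * bytes_per_value

fp32_states = adamw_state_bytes(NUM_PARAMS, 4)  # standard 32-bit AdamW
int8_states = adamw_state_bytes(NUM_PARAMS, 1)  # 8-bit AdamW (optim="adamw_8bit")

print(f"32-bit states: {fp32_states / GIB:.2f} GiB")
print(f" 8-bit states: {int8_states / GIB:.2f} GiB")
print(f"saved:         {(fp32_states - int8_states) / GIB:.2f} GiB")
```

On a 16 GB T4 that also has to hold the weights, the gradients, the activations, and the vLLM engine, a saving of this order is significant.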

<br>

---

On an NVIDIA GPU with the Ampere architecture or later, you can further reduce VRAM usage by:


*   enabling `attn_implementation="flash_attention_2"` during model loading
*   loading the model with [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) wrapper:

      ```Python
      from liger_kernel.transformers import AutoLigerKernelForCausalLM
      model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
      ```
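
FlashAttention-2 targets GPUs with compute capability 8.0 or higher (Ampere and newer), while the T4 is a Turing card with capability 7.5. Below is a minimal sketch of choosing the attention backend from the device capability. `pick_attn_implementation` is a hypothetical helper, and `"sdpa"` is assumed as the fallback:

```python
def pick_attn_implementation(capability: tuple[int, int]) -> str:
    """Return a value for `attn_implementation` given a (major, minor) capability."""
    major, _minor = capability
    return "flash_attention_2" if major >= 8 else "sdpa"

print(pick_attn_implementation((7, 5)))  # T4 (Turing)
print(pick_attn_implementation((8, 0)))  # A100 (Ampere)

# In the notebook, the capability could come from PyTorch:
#   capability = torch.cuda.get_device_capability()
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, attn_implementation=pick_attn_implementation(capability)
#   )
```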


In [None]:
%%capture
!pip install uv
!uv pip install --system git+https://github.com/qunash/trl-1.git@grpo-vram-optimization
!uv pip install --system triton==2.2.0
!uv pip install --system vllm
!uv pip install --system bitsandbytes

In [None]:
import os
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl.trainer import GRPOConfig, GRPOTrainer


R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."


def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
    dataset = load_dataset(dataset_name, 'main')[split]

    def extract_hash_answer(text: str) -> str | None:
        try:
            return text.split("####")[1].strip()
        except IndexError:
            return None

    def process_batch(batch):
        prompts = [[
            {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
            {'role': 'user', 'content': "What is 2+2?"},
            {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"},
            {'role': 'user', 'content': q.strip()}
        ] for q in batch['question']]

        return {
            'prompt': prompts,
            'answer': [extract_hash_answer(a) for a in batch['answer']]
        }

    return dataset.map(process_batch, batched=True, batch_size=chunk_size)

dataset_name = 'openai/gsm8k'
dataset = preprocess_dataset(dataset_name, chunk_size=500)


def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        return ""

# reward functions
# VALID_FORMAT = re.compile(r"<reasoning>(?:(?!</?reasoning>|</?answer>).)*</reasoning>\n<answer>(?:(?!</?reasoning>|</?answer>).)*</answer>")

# def format_reward_func(completions, **kwargs) -> list[float]:
#     """Reward function that checks if the completion has the correct format."""
#     responses = [completion[0]["content"] for completion in completions]
#     matches = [bool(VALID_FORMAT.fullmatch(r.strip())) for r in responses]
#     return [1.0 if match else 0.0 for match in matches]

def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the correct format."""
    # Note: "." does not match a newline here (no re.DOTALL), so the reasoning
    # and the answer must each fit on a single line to earn the reward.
    pattern = r"^<reasoning>(?:(?!</reasoning>).)*</reasoning>\n<answer>(?:(?!</answer>).)*</answer>$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [bool(re.match(pattern, r)) for r in responses]
    return [1.0 if match else 0.0 for match in matches]

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that checks if the answer is correct."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(f"Question: {prompts[0][-1]['content']}\nAnswer: {answer[0]}\nResponse: {responses[0]}\nExtracted: {extracted_responses[0]}")
    print(''.join('✅' if r == a else '❌' for r, a in zip(extracted_responses, answer)))
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

# model_name = "Qwen/Qwen2.5-0.5B"
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

output_dir = f"outputs/{model_name.split('/')[-1]}-GRPO"
run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}"


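# Reduce allocator fragmentation: the caching allocator will not split blocks
# larger than 128 MB. Set this before any CUDA memory is allocated.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

max_prompt_length = 256
max_completion_length = 512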

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=1e-5,
    beta=0.005,  # KL penalty coefficient: higher values keep the policy closer to the reference model (more conservative updates). Default is 0.04
    optim="adamw_8bit",
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4,
    num_generations=4,  # group size
    gradient_accumulation_steps=4,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
    use_vllm=True,
    vllm_init_kwargs={
        "device": "cuda:0",
        "gpu_memory_utilization": 0.3,
        "max_model_len": max_prompt_length + max_completion_length,
        "dtype": "half",
        # "enable_chunked_prefill": True,
        # "max_num_batched_tokens": 2048,
    },
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logit_computation_mini_batch_size=1,
    enable_profiling=False
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2", # T4 is not supported
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=training_args.max_completion_length,
)
tokenizer.pad_token = tokenizer.eos_token

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        format_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
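
The format reward can be checked in isolation before spending GPU time. The sketch below copies `format_reward_func` from the training cell and scores two hand-written completions (the sample strings are made up for illustration). Because the pattern is matched without `re.DOTALL`, `.` does not match a newline, so reasoning that spans several lines scores 0.0:

```python
import re

def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the correct format."""
    pattern = r"^<reasoning>(?:(?!</reasoning>).)*</reasoning>\n<answer>(?:(?!</answer>).)*</answer>$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [bool(re.match(pattern, r)) for r in responses]
    return [1.0 if match else 0.0 for match in matches]

single_line = "<reasoning>2 + 2 = 4.</reasoning>\n<answer>4</answer>"
multi_line = "<reasoning>First add 2 and 2.\nThat gives 4.</reasoning>\n<answer>4</answer>"

# Completions use the same chat-message structure the trainer passes in.
completions = [[{"role": "assistant", "content": c}] for c in (single_line, multi_line)]
scores = format_reward_func(completions)
print(scores)  # [1.0, 0.0]
```

If multi-line reasoning should also be rewarded, passing `re.DOTALL` to `re.match` is one option. That changes the reward signal, so results may differ from the ones reported above.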


# Eval
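
Evaluation counts a sample as correct only when the string extracted from the model's `<answer>` tags equals the string after `####` in the GSM8K reference. Below is a short sketch of how the two extractors behave, re-implemented here so that the block runs on its own. The sample strings are illustrative and are not taken from the dataset:

```python
from typing import Optional

def extract_xml_answer(text: str) -> str:
    """Text between the last <answer> tag and the following </answer>."""
    return text.split("<answer>")[-1].split("</answer>")[0].strip()

def extract_hash_answer(text: str) -> Optional[str]:
    """Final answer of a GSM8K reference solution: the text after '####'."""
    parts = text.split("####")
    return parts[1].strip() if len(parts) > 1 else None

reference = "She sells 9 * 2 = 18 dollars of eggs.\n#### 18"
good = "<reasoning>9 * 2 = 18</reasoning>\n<answer>18</answer>"
padded = "<reasoning>9 * 2 = 18</reasoning>\n<answer>$18</answer>"
untagged = "The answer is 18."

target = extract_hash_answer(reference)
print(target)                                # 18
print(extract_xml_answer(good) == target)    # True
print(extract_xml_answer(padded) == target)  # False: "$18" is not an exact match
print(extract_xml_answer(untagged))          # The answer is 18.
```

Note that a response without `<answer>` tags is returned whole, so it only counts as correct if the entire response equals the reference answer.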

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm.notebook import tqdm
from typing import Dict
import json
from datetime import datetime
import logging

# Reduce vLLM's log verbosity (only warnings and errors are shown)
logging.getLogger("vllm").setLevel(logging.WARNING)

# Constants from training script
R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        return ""

def extract_hash_answer(text: str) -> str | None:
    try:
        return text.split("####")[1].strip()
    except IndexError:
        return None

def evaluate_model(
    model_path: str,
    batch_size: int = 4,
    num_samples: int | None = None,
    save_results: bool = True,
    gpu_memory_utilization: float = 0.3,
) -> Dict:
    print("Initializing evaluation...")

    # Initialize VLLM with progress indicator
    with tqdm(total=2, desc="Loading model components") as pbar:
        llm = LLM(
            model=model_path,
            dtype="half",
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=768,
            device="cuda:0",
            enable_chunked_prefill=True,
        )
        pbar.update(1)

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=768,
            padding_side='right',
            truncation_side='right'
        )
        pbar.update(1)

    # Set up sampling parameters
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=512,  # Matching max_completion_length from training
        stop_token_ids=[tokenizer.eos_token_id],
    )

    # Load test dataset
    print("Loading dataset...")
    dataset = load_dataset('openai/gsm8k', 'main', split='test')
    if num_samples:
        dataset = dataset.select(range(num_samples))
    total_samples = len(dataset)
    print(f"Loaded {total_samples} samples")

    results = []
    correct = 0
    total = 0

    # Create progress bar
    progress_bar = tqdm(
        total=total_samples,
        desc="Processing samples",
        unit="examples",
        dynamic_ncols=True,
    )

    progress_bar.set_postfix({
        'acc': '0.00%',
        'correct': '0',
    })

    # Process in batches
    for i in range(0, total_samples, batch_size):
        batch_data = dataset[i:i + batch_size]
        current_batch_size = len(batch_data['question'])

        # Prepare prompts using same format as training
        prompts = [
            [
                {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
                {'role': 'user', 'content': "What is 2+2?"},
                {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"},
                {'role': 'user', 'content': q.strip()}
            ] for q in batch_data['question']
        ]

        # Convert to chat format
        formatted_prompts = [
            tokenizer.apply_chat_template(
                p,
                tokenize=False,
                add_generation_prompt=True
            )
            for p in prompts
        ]

        # Generate responses (use_tqdm=False hides vLLM's per-batch progress bar)
        outputs = llm.generate(
            formatted_prompts,
            sampling_params,
            use_tqdm=False,
        )

        # Process responses
        for j, output in enumerate(outputs):
            response = output.outputs[0].text

            # Extract answers
            generated_answer = extract_xml_answer(response)
            true_answer = extract_hash_answer(batch_data['answer'][j])

            # Store result
            result = {
                'question': batch_data['question'][j],
                'true_answer': true_answer,
                'generated_answer': generated_answer,
                'full_response': response,
                'correct': generated_answer == true_answer
            }
            results.append(result)

            # Update metrics
            if generated_answer == true_answer:
                correct += 1
            total += 1

        # Update progress
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{(correct/total)*100:.2f}%',
            'correct': f'{correct}/{total}',
        })

    progress_bar.close()

    # Calculate metrics
    accuracy = correct / total if total > 0 else 0
    metrics = {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'model_path': model_path,
        'timestamp': datetime.now().isoformat()
    }

    # Save results
    if save_results:
        save_path = f"gsm8k_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(save_path, 'w') as f:
            json.dump({
                'metrics': metrics,
                'results': results
            }, f, indent=2)
        print(f"\nResults saved to {save_path}")

    return metrics

print("Starting GSM8K evaluation...")
checkpoint_path = "outputs/Qwen2.5-0.5B-Instruct-GRPO/checkpoint-latest"  # Placeholder: the trainer saves checkpoints as checkpoint-<step>, so point this at one that exists

metrics = evaluate_model(
    model_path=checkpoint_path,
    batch_size=4,
    num_samples=None,
    save_results=True,
    gpu_memory_utilization=0.3,
)

print("\nFinal Evaluation Results:")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Correct: {metrics['correct']}/{metrics['total']}")
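
The `checkpoint-latest` path above is a placeholder. Assuming the trainer follows the Hugging Face `Trainer` convention of writing `checkpoint-<step>` directories under `output_dir`, a small helper can pick the most recent one. `latest_checkpoint` is a hypothetical name; `transformers.trainer_utils.get_last_checkpoint` offers similar functionality:

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> directory with the highest step, if any."""
    best_step, best_path = -1, None
    for path in Path(output_dir).glob("checkpoint-*"):
        match = re.fullmatch(r"checkpoint-(\d+)", path.name)
        # Compare steps numerically so that 1000 sorts after 200.
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(path)
    return best_path

# Demonstration on a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 1000):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    found = latest_checkpoint(tmp)
    print(Path(found).name)  # checkpoint-1000
```

In the notebook, the result of `latest_checkpoint("outputs/Qwen2.5-0.5B-Instruct-GRPO")` could then be passed as `model_path`.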