@jarrelscy
Created February 11, 2025 23:17

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch
import wandb  # Import wandb for logging

# Initialize wandb with your project name
wandb.init(project="unsloth-grpo")

max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)
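# Optional check (added, not in the original gist): PEFT-wrapped models usually
# expose print_trainable_parameters(), which is handy to confirm that only the
# LoRA weights (rank 64 here) will be updated during GRPO training.
# model.print_trainable_parameters()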

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for the optional 1-shot prompting example (unused in this zero-shot setup)
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore

dataset = get_gsm8k_questions()
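# Optional sanity check (added, not in the original gist): inspect one mapped
# example to confirm the chat-style prompt and the extracted numeric answer.
# print(dataset[0]["prompt"])
# print(dataset[0]["answer"])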

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that exactly match the newline-delimited <reasoning>/<answer> layout."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely contain <reasoning>...</reasoning> followed by <answer>...</answer>."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
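
# Quick sanity check (added, not in the original gist): a completion with
# single-line reasoning in the expected layout should pass the strict format
# check and earn the maximum 0.5 from count_xml.
_sample = "<reasoning>\nStep-by-step work.\n</reasoning>\n<answer>\n7\n</answer>\n"
assert extract_xml_answer(_sample) == "7"
assert strict_format_reward_func([[{"content": _sample}]]) == [0.5]
assert xmlcount_reward_func([[{"content": _sample}]]) == [0.5]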

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_steps=25,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=256,
    max_completion_length=256,
    # num_train_epochs=1,  # Set to 1 for a full training run
    max_steps=65536,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="wandb",  # Log metrics to the wandb run initialized above
    output_dir="outputs",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
model.save_lora("grpo_saved_lora")
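
# Inference sketch (added, not in the original gist): one way to try the saved
# adapter, assuming Unsloth's fast_generate / load_lora helpers that come with
# fast_inference=True, as used in Unsloth's GRPO notebook.
from vllm import SamplingParams

text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "What is 7 + 5?"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)
output = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)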