    import re
    import torch
    import numpy as np
    from datasets import Dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    class SimpleCalculator:
        """Simple calculator tool implementation."""
        def __call__(self, expression: str) -> float:
            # Strip everything except digits, arithmetic operators, parentheses,
            # and dots, then evaluate the remaining expression.
            clean_expr = re.sub(r'[^0-9+\-*/().]', '', expression)
            try:
                # eval only ever sees the sanitized arithmetic characters above
                return float(eval(clean_expr))
            except Exception:
                return 0.0
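
    # Illustrative behaviour of the sanitizer and the eval fallback:
    #   SimpleCalculator()("what is 7*6?")  -> 42.0   (non-arithmetic characters stripped)
    #   SimpleCalculator()("no digits")     -> 0.0    (eval fails, fallback value)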

    # System prompt demonstrating tool usage
    SYSTEM_PROMPT = """
    To solve calculations, use the SimpleCalculatorTool in this format:
    <request><SimpleCalculatorTool>[expression]<call>[result]<response>
    Then provide the final answer as:
    Result=[number]<submit>

    Example:
    What is 13-3?
    <request><SimpleCalculatorTool>13-3<call>10.0<response>
    Result=10<submit>

    What is 4*3?
    <request><SimpleCalculatorTool>4*3<call>12.0<response>
    Result=12<submit>
    """

    def generate_data(n):
        """Generate random arithmetic tasks and answers."""
        tasks = []
        calculator = SimpleCalculator()

        for _ in range(n):
            a = np.random.randint(0, 50)
            b = np.random.randint(0, 50)
            op = np.random.choice(["-", "+", "*"])
            expression = f"{a}{op}{b}"

            # Get answer using the calculator
            result = calculator(expression)

            tasks.append({
                'prompt': [
                    {'role': 'system', 'content': SYSTEM_PROMPT},
                    {'role': 'user', 'content': f"What is {expression}?"}
                ],
                'answer': str(result)
            })
        return Dataset.from_list(tasks)
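
    # Illustrative dataset record (values vary per run):
    #   {'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},
    #               {'role': 'user', 'content': 'What is 7*6?'}],
    #    'answer': '42.0'}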

    def extract_tool_usage(text: str) -> tuple[bool, str | None]:
        """Check for proper tool usage and extract the final answer."""
        # Check for proper tool request format
        tool_pattern = r"<request><SimpleCalculatorTool>(.*?)<call>(.*?)<response>"
        tool_match = re.search(tool_pattern, text)

        # Check for final answer format
        answer_pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
        answer_match = re.search(answer_pattern, text)

        used_tool = bool(tool_match)
        final_answer = answer_match.group(1) if answer_match else None

        return used_tool, final_answer
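
    # Illustrative parse of a well-formed completion:
    #   extract_tool_usage("<request><SimpleCalculatorTool>4*3<call>12.0<response>Result=12<submit>")
    #   -> (True, "12")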

    def batch_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
        """
        Reward function that checks for proper tool usage and correct answers.
        """
        rewards = []

        for completion, ans in zip(completions, answer):
            response = completion[0]['content']
            reward = 0.0

            # Check tool usage and extract answer
            used_tool, predicted = extract_tool_usage(response)

            # 1. Tool usage reward (1.0)
            if used_tool:
                reward += 1.0

            # 2. Correctness reward (2.0)
            if predicted and abs(float(predicted) - float(ans)) < 0.01:
                reward += 2.0

            # 3. Format reward (0.5)
            if "<submit>" in response:
                reward += 0.5

            rewards.append(reward)

            # Debug printing
            print('-'*20)
            print(f"Response:\n{response}")
            print(f"Used tool: {used_tool}")
            print(f"Predicted: {predicted}")
            print(f"Answer: {ans}")
            print(f"Reward: {reward}")

        return rewards
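
    # The maximum per-completion reward is 3.5 (1.0 tool markup + 2.0 correctness
    # + 0.5 submit tag), so partial credit is available for getting only the
    # format or only the arithmetic right.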

    # Model configuration
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    output_dir = "outputs/Qwen-0.5B-GRPO-calculator-tool"
    run_name = "Qwen-0.5B-GRPO-calculator-tool"

    # Training configuration
    training_args = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,

        # Optimizer settings
        learning_rate=1.41e-5,
        optim="adamw_torch_fused",

        # Training settings
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=64,
        num_train_epochs=1,
        save_steps=100,

        # Generation settings
        num_generations=8,
        max_prompt_length=128,
        max_completion_length=128,

        # Memory optimizations
        gradient_checkpointing=True,

        # Reporting
        report_to="wandb",
    )
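
    # GRPO samples num_generations completions per prompt and normalizes rewards
    # within each group of 8. Depending on the TRL version, the per-device batch is
    # counted in completions, so 4 x 64 accumulation steps would correspond to
    # roughly 32 prompt groups per optimizer step on a single GPU.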

    # LoRA configuration
    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen attention modules
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Enable optimizations
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Generate dataset with examples using the calculator tool
    dataset = generate_data(1000)

    # Initialize trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[batch_reward_func],
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )

    # Start training
    if __name__ == "__main__":
        trainer.train()
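
    # To launch (the filename is hypothetical; save this script under any name):
    #   python grpo_calculator.py
    # or, for multi-GPU setups:
    #   accelerate launch grpo_calculator.py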