Created
February 1, 2025 22:04
-
-
Save data2json/effdcde9ba41bf37bb538aa0717c11ce to your computer and use it in GitHub Desktop.
GRPO-Function_Calling_Qwen2.5-0.5B-Instruct
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import re | |
| import torch | |
| import numpy as np | |
| from datasets import Dataset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
class SimpleCalculator:
    """Simple calculator tool: evaluates a basic arithmetic expression.

    Non-arithmetic characters are stripped before evaluation; any
    evaluation failure (syntax error, division by zero, ...) yields 0.0
    instead of raising.
    """

    def __call__(self, expression: str) -> float:
        # Keep only digits, +, -, *, /, parentheses, and dots.
        clean_expr = re.sub(r'[^0-9+\-*/().]', '', expression)
        try:
            # NOTE(review): eval on filtered text is still not fully safe —
            # two '*' chars survive the filter, so e.g. "9**999999" can hang
            # or exhaust memory. Acceptable here because expressions are
            # generated locally; never feed untrusted input to this tool.
            return float(eval(clean_expr))
        except Exception:
            # Narrowed from bare `except:` so KeyboardInterrupt/SystemExit
            # are not silently swallowed.
            return 0.0
# System prompt demonstrating tool usage.
# The model is trained to emit a calculator request in the exact
# <request>...<call>...<response> format below, then submit its final
# answer as "Result=<number><submit>". The reward function pattern-matches
# on these literal markers, so the prompt text must stay in sync with it.
SYSTEM_PROMPT = """
To solve calculations, use the SimpleCalculatorTool in this format:
<request><SimpleCalculatorTool>[expression]<call>[result]<response>
Then provide the final answer as:
Result=[number]<submit>
Example:
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12.0<response>
Result=12<submit>
"""
def generate_data(n):
    """Generate random arithmetic tasks and answers.

    Builds ``n`` chat-style prompts of the form "What is a<op>b?" over small
    integers and the -, +, * operators, with the ground-truth answer computed
    by the same SimpleCalculator tool the model is taught to call.
    """
    calculator = SimpleCalculator()
    records = []
    for _ in range(n):
        left = np.random.randint(0, 50)
        right = np.random.randint(0, 50)
        operator = np.random.choice(["-", "+", "*"])
        expression = f"{left}{operator}{right}"
        # Ground truth comes from the calculator itself, so tool output and
        # reference answer can never disagree.
        answer = calculator(expression)
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': f"What is {expression}?"},
        ]
        records.append({'prompt': messages, 'answer': str(answer)})
    return Dataset.from_list(records)
| def extract_tool_usage(text: str) -> tuple[bool, str | None]: | |
| """Check for proper tool usage and extract final answer.""" | |
| # Check for proper tool request format | |
| tool_pattern = r"<request><SimpleCalculatorTool>(.*?)<call>(.*?)<response>" | |
| tool_match = re.search(tool_pattern, text) | |
| # Check for final answer format | |
| answer_pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" | |
| answer_match = re.search(answer_pattern, text) | |
| used_tool = bool(tool_match) | |
| final_answer = answer_match.group(1) if answer_match else None | |
| return used_tool, final_answer | |
def batch_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that checks for proper tool usage and correct answers.

    Args:
        prompts: Batch of prompts (unused; required by the GRPO reward API).
        completions: Batch of completions; each is a list of messages and
            only the first message's ``content`` is scored.
        answer: Ground-truth answers, one per completion, as numeric strings.

    Returns:
        One reward per completion, summing three components:
          +1.0 for a well-formed calculator call,
          +2.0 for a submitted answer within 0.01 of the ground truth,
          +0.5 for including the ``<submit>`` marker.
    """
    # (fix) dropped `calculator = SimpleCalculator()` — it was instantiated
    # here but never used; answers are compared against `answer` directly.
    rewards = []
    for completion, ans in zip(completions, answer):
        response = completion[0]['content']
        reward = 0.0
        # Check tool usage and extract answer
        used_tool, predicted = extract_tool_usage(response)
        # 1. Tool usage reward (1.0)
        if used_tool:
            reward += 1.0
        # 2. Correctness reward (2.0) — tolerance absorbs float formatting
        # differences like "10" vs "10.0".
        if predicted and abs(float(predicted) - float(ans)) < 0.01:
            reward += 2.0
        # 3. Format reward (0.5)
        if "<submit>" in response:
            reward += 0.5
        rewards.append(reward)
        # Debug printing
        print('-'*20)
        print(f"Response:\n{response}")
        print(f"Used tool: {used_tool}")
        print(f"Predicted: {predicted}")
        print(f"Answer: {ans}")
        print(f"Reward: {reward}")
    return rewards
# Model configuration
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
output_dir = "outputs/Qwen-0.5B-GRPO-calculator-tool"
run_name = "Qwen-0.5B-GRPO-calculator-tool"
# Training configuration
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    # Optimizer settings
    learning_rate=1.41e-5,
    optim="adamw_torch_fused",
    # Training settings
    logging_steps=1,
    bf16=True,
    # Effective batch size is 4 * 64 = 256 prompts per optimizer step.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=64,
    num_train_epochs=1,
    save_steps=100,
    # Generation settings — GRPO samples 8 completions per prompt; prompt
    # and completion are each capped at 128 tokens.
    num_generations=8,
    max_prompt_length=128,
    max_completion_length=128,
    # Memory optimizations
    gradient_checkpointing=True,
    # Reporting
    report_to="wandb",
)
# LoRA configuration: low-rank adapters (r=8, alpha=32) on the attention
# projections only, so only a small fraction of weights is trained.
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen attention modules
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
# Load model in bfloat16 directly onto the GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to("cuda")
# Enable optimizations: the KV cache is incompatible with gradient
# checkpointing during training, so it is disabled here.
model.config.use_cache = False
model.gradient_checkpointing_enable()
# Load tokenizer; Qwen ships no pad token, so reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Generate dataset with examples using the calculator tool
dataset = generate_data(1000)
# Initialize trainer: GRPO with the custom tool-usage reward and LoRA
# adapters applied via peft_config.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[batch_reward_func],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config
)
# Start training
if __name__ == "__main__":
    trainer.train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment