Created February 11, 2025 23:17
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch
import wandb  # Import wandb for logging

# Initialize wandb with your project name
wandb.init(project="unsloth-grpo")

max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_steps=25,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=256,
    max_completion_length=256,
    # num_train_epochs=1,  # Set to 1 for a full training run
    max_steps=65536,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="wandb",  # Change from "none" to "wandb" to log to wandb
    output_dir="outputs",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

model.save_lora("grpo_saved_lora")
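# ---------------------------------------------------------------------------
# Optional sanity check after training (not part of the original gist): a
# rough sketch of how the saved LoRA could be exercised through Unsloth's
# vLLM backend. model.fast_generate and model.load_lora follow the pattern
# used in Unsloth's GRPO notebooks and may differ between Unsloth versions;
# the test question below is only an illustrative placeholder.
from vllm import SamplingParams

prompt_text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "What is 7 * 8 + 12?"},  # placeholder question
    ],
    tokenize=False,
    add_generation_prompt=True,
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)

output = model.fast_generate(
    prompt_text,
    sampling_params=sampling_params,
    lora_request=model.load_lora("grpo_saved_lora"),  # apply the GRPO-trained adapter
)[0].outputs[0].text

print(output)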