GRPO Llama-1B
# train_grpo.py - For single 4090 24GB PEFT config

# Import packages
import os
import re
import torch
import random
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer

os.environ['WANDB_NOTEBOOK_NAME'] = '20250201_trial_3'

# Load and prep dataset
def format_reward_func(completions, **kwargs):
    pattern = r"\n#### The final answer is \d+"
    completion_contents = [completion for completion in completions]
    matches = [re.search(pattern, content) for content in completion_contents]
    return [0.5 if match else 0.0 for match in matches]
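# Illustrative check (hypothetical completions, not from a real run):
#   format_reward_func(["6 more were planted.\n#### The final answer is 6", "It is 6."])
#   -> [0.5, 0.0]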

def extract_xml_answer(text: str) -> str:
    if "<answer>" in text and "</answer>" in text:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
    else:
        # The qa template ends with "#### The final answer is N", so try that
        # pattern before falling back to the full text.
        match = re.search(r"#### The final answer is (\d+)", text)
        answer = match.group(1) if match else text.strip()
    return answer
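# e.g. extract_xml_answer("...\n#### The final answer is 39") -> "39"
#      extract_xml_answer("<answer> 39 </answer>") -> "39"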

def correctness_reward_func(prompts, completions, final_answer, **kwargs) -> list[float]:
    # Prompts and completions are plain strings here (standard, non-chat format),
    # and GRPOTrainer passes extra dataset columns (like final_answer) as kwargs.
    responses = list(completions)
    q = prompts[0]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Debugging print
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{final_answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    # Use numerical similarity for rewards
    rewards = []
    for r, a in zip(extracted_responses, final_answer):
        try:
            # Convert to float for comparison (some GSM8K answers contain commas)
            r_num, a_num = float(r.replace(',', '')), float(a.replace(',', ''))
            if abs(r_num - a_num) < 1e-3:
                rewards.append(2.0)  # Exact match
            elif abs(r_num - a_num) < 0.1:
                rewards.append(1.5)  # Very close
            elif abs(r_num - a_num) < 1.0:
                rewards.append(1.0)  # Somewhat close
            else:
                rewards.append(0.0)  # Wrong
        except ValueError:
            rewards.append(0.0)  # Failed to convert -> Wrong answer
    return rewards
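
# Worked example (hypothetical values): target "39" vs. extracted "39" -> 2.0,
# "39.05" -> 1.5, "39.5" -> 1.0, "42" or unparseable text -> 0.0.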

class GSM8K:
    def __init__(
        self, split, include_answer=True, include_reasoning=True, few_shot=False,
        num_shots=8, seed=None, cot=False, template="qa"
    ):
        self.split = split
        self.include_answer = include_answer
        self.include_reasoning = include_reasoning
        self.seed = seed
        if self.seed is not None:
            random.seed(self.seed)
        self.few_shot = few_shot
        self.num_shots = num_shots
        self.cot = cot
        self.template = template
        self.examples = [
            {
                "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
                "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
                "short_answer": "6"
            },
            {
                "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
                "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
                "short_answer": "5"
            },
            {
                "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
                "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
                "short_answer": "39"
            },
            {
                "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
                "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
                "short_answer": "8"
            },
            {
                "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
                "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
                "short_answer": "9"
            },
            {
                "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
                "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
                "short_answer": "29"
            },
            {
                "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
                "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
                "short_answer": "33"
            },
            {
                "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
                "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
                "short_answer": "8"
            }
        ]
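        # These eight worked QA pairs are the standard chain-of-thought
        # exemplars commonly used for few-shot GSM8K prompting (Wei et al., 2022).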
        self.dataset = self.load_dataset()

    def format_example(self, question, solution, answer):
        example = ''
        if self.template == 'qa':
            example = f"Question: {question}\nSolution: "
            if self.cot:
                example += "Let's break it down step by step. "
                # example += '\n'
            if solution is not None:
                def remove_placeholders(text):
                    # Regex to match <<anything>> calculator annotations
                    return re.sub(r'<<.*?>>', '', text)
                solution = '. '.join(solution.split('\n'))
                solution = remove_placeholders(solution)
                example += f"{solution}.\n"
                example = example.replace('..', '.')
            if answer is not None:
                example += f"#### The final answer is {answer}\n\n"
        else:
            raise ValueError('Format Not Implemented')
        return example
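
    # With cot=True, a fully rendered training example looks like:
    #   Question: <question text>
    #   Solution: Let's break it down step by step. <reasoning sentences>.
    #   #### The final answer is <N>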

    def process_example(self, example, index):
        question = example['question']
        answer = example['answer']
        # Extract the reasoning steps and the final answer
        answer_delim = "#### "
        if answer_delim in answer:
            reasoning = answer.split(answer_delim)[0].strip()
            final_answer = answer.split(answer_delim)[-1].strip()
        else:
            reasoning = answer.strip()
            final_answer = ''
        # Create the prompt
        if self.include_answer:
            if self.include_reasoning:
                input_text = self.format_example(question, reasoning, final_answer)
            else:
                input_text = self.format_example(question, None, final_answer)
        else:
            input_text = self.format_example(question, None, None)
        if self.few_shot:
            input_text = self.few_shot_prompt + input_text
        return {
            'prompt': input_text,
            'final_answer': final_answer,
            'question': question,
        }
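
    # NOTE: datasets.map() keeps the original gsm8k columns, so each row carries
    # 'question', 'answer', 'prompt', and 'final_answer'. GRPOTrainer forwards
    # the non-prompt columns to the reward functions as keyword arguments, which
    # is how correctness_reward_func receives final_answer above.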

    def load_dataset(self):
        # Load the GSM8K dataset with the specified split
        dataset = load_dataset('gsm8k', 'main', split=self.split)
        if self.few_shot:
            self.few_shot_prompt = self.build_prompt()
        dataset = dataset.map(self.process_example, with_indices=True, load_from_cache_file=False)
        return dataset

    def fewshot_examples_qa(self):
        return self.examples

    def make_prompts(self):
        """Builds the prompt for the LM to generate from."""
        if self.template == 'qa':
            examples = self.fewshot_examples_qa()
        else:
            raise ValueError('Format Not Implemented')
        self.examples = examples

    def build_prompt(self):
        if self.examples is None:
            self.make_prompts()
        prompt = ""
        for qna in random.sample(self.examples, self.num_shots):
            prompt += self.format_example(qna['question'], qna['cot_answer'], qna['short_answer'])
        return prompt
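
    # random.sample draws from the module-level RNG, so the sampled few-shot
    # exemplars differ between runs unless a seed is passed to the constructor.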

dataset = GSM8K(
    split='train', include_answer=False, include_reasoning=True,
    few_shot=True, num_shots=2, seed=None, cot=True, template="qa"
).dataset.shuffle(seed=42)

# Model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# model_name = "Qwen/Qwen2.5-Math-1.5B"
output_dir = f'/mnt/d/outputs/GRPO/{model_name}'

training_args = GRPOConfig(
    output_dir=output_dir,
    # Double quotes outside, single inside: nesting the same quote character
    # in an f-string is a syntax error before Python 3.12.
    run_name=f"GRPO-GSM8K-{model_name.split('/')[-1]}",
    learning_rate=2e-6,
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=2,
    max_prompt_length=256,
    max_completion_length=256,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to='wandb',
    log_on_each_node=False,
    # use_vllm=True,
    # vllm_device='auto',
    warmup_ratio=0.07,
    beta=0.2
)
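
# Sizing notes for a single 24GB card: each optimizer step accumulates over
# gradient_accumulation_steps=4 prompts, each sampled num_generations=2 times;
# beta is the KL-penalty coefficient that keeps the policy near the reference
# model. Scale these up if you have more VRAM.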

rank = 8
peft_config = LoraConfig(
    r=rank,
    lora_alpha=rank*2,
    # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    target_modules=["q_proj", "v_proj", "o_proj"],  # Fewer layers for LoRA
    task_type="CAUSAL_LM",
    bias='lora_only',
    lora_dropout=0.08,
)
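
# lora_alpha = 2*rank is a common heuristic (effective scaling alpha/r = 2).
# Restricting LoRA to the q/v/o projections trades some adapter capacity for
# memory; the fuller target_modules list above is there if VRAM allows.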

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    device_map='auto'
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        format_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

model.save_pretrained(output_dir)
print(f"LoRA model and configuration saved to {output_dir}")