- 
            
      
        
      
    Star
      
          
          (1,308)
      
  
You must be signed in to star a gist 
- 
              
      
        
      
    Fork
      
          
          (391)
      
  
You must be signed in to fork a gist 
- 
      
- 
        Save willccbb/4676755236bb08cab5f4e54a0475d6fb to your computer and use it in GitHub Desktop. 
| # train_grpo.py | |
| import re | |
| from datasets import load_dataset, Dataset | |
| from transformers import AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
# Load and prep dataset

# System prompt asking the model to put its chain of thought inside
# <reasoning> tags and its final answer inside <answer> tags.
SYSTEM_PROMPT = "\n".join([
    "",
    "Respond in the following format:",
    "<reasoning>",
    "...",
    "</reasoning>",
    "<answer>",
    "...",
    "</answer>",
    "",
])

# Template for a worked example in the same tag layout;
# fill it with .format(reasoning=..., answer=...).
XML_COT_FORMAT = (
    "<reasoning>\n"
    "{reasoning}\n"
    "</reasoning>\n"
    "<answer>\n"
    "{answer}\n"
    "</answer>\n"
)
def extract_xml_answer(text: str) -> str:
    """Return the stripped text of the last ``<answer>`` section of ``text``.

    Everything after the final ``<answer>`` tag is taken and then cut at the
    first ``</answer>`` that follows it. Missing tags are tolerated: without
    ``<answer>`` the whole text is used, and without ``</answer>`` the text runs
    to the end of the string.
    """
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()
| def extract_hash_answer(text: str) -> str | None: | |
| if "####" not in text: | |
| return None | |
| return text.split("####")[1].strip() | |
def get_gsm8k_questions(split="train") -> Dataset:
    """Load a GSM8K split and turn each row into a chat prompt plus reference answer.

    Each mapped example gets:
      * ``prompt``: the system prompt, a one-shot worked example (a user question
        and an assistant reply rendered with ``XML_COT_FORMAT``), and the GSM8K
        question as the final user turn;
      * ``answer``: ``extract_hash_answer`` applied to the GSM8K solution text
        (``None`` if it has no ``####`` marker).

    Args:
        split: Name of the split to load from ``openai/gsm8k`` (default ``"train"``).

    Returns:
        The result of ``Dataset.map`` over the requested split.
    """
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # One-shot example showing the expected <reasoning>/<answer> layout.
            {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            {'role': 'assistant', 'content': XML_COT_FORMAT.format(
                reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                answer="7"
            )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()
| # Reward functions | |
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Give 2.0 to each completion whose extracted answer matches the reference exactly.

    Args:
        prompts: Chat prompts for the batch. Only the last message of the first
            prompt is read, for the debug print.
        completions: One entry per completion; the text is ``completion[0]['content']``.
        answer: Reference answers, paired positionally with ``completions``.

    Returns:
        ``2.0`` when the ``<answer>`` text equals the reference string, else ``0.0``.
    """
    texts = [completion[0]['content'] for completion in completions]
    question = prompts[0][-1]['content']
    predicted = [extract_xml_answer(text) for text in texts]
    # Print the first sample of every batch so progress can be eyeballed during training.
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{predicted[0]}")
    return [2.0 if guess == reference else 0.0 for guess, reference in zip(predicted, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 when the extracted ``<answer>`` text consists only of digits.

    The check is ``str.isdigit``, so negative numbers, decimals and numbers
    with thousands separators score 0.0.
    """
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if extracted.isdigit() else 0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that follow the exact newline-delimited tag layout.

    The whole completion must be::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    followed by a trailing newline. ``re.DOTALL`` lets the reasoning and
    answer bodies span several lines. Without it, ``.`` stops at newlines and
    any multi-line reasoning would never earn this reward.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that start with a reasoning block followed by an answer block.

    This is looser than the strict format check. Any whitespace may separate
    the two blocks, no particular newlines are required inside them, and text
    after ``</answer>`` is allowed (``re.match`` only anchors at the start).
    ``re.DOTALL`` lets ``.*?`` cross newlines so multi-line reasoning still matches.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text: str) -> float:
    r"""Return a partial-credit score for the XML tag layout of ``text``.

    Each of the tag patterns ``<reasoning>\n``, ``\n</reasoning>\n``,
    ``\n<answer>\n`` and ``\n</answer>`` adds 0.125 when it occurs exactly once.
    Inside the last two checks, text after the closing answer tag is
    penalised at 0.001 per character.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalise characters after "\n</answer>\n".
        # NOTE(review): if "\n</answer>\n" never occurs (e.g. the completion ends
        # with "\n</answer>" and no trailing newline), split()[-1] is the whole
        # text, so the entire completion length is penalised. Confirm this is intended.
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # The "- 1" discounts the single newline expected after the closing tag;
        # when nothing at all follows the tag this term adds +0.001 instead.
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion's tag layout with ``count_xml``."""
    return [count_xml(completion[0]["content"]) for completion in completions]
# GRPO run configuration.
# NOTE(review): the output/run names say "base", but model_name below is an
# -Instruct checkpoint; rename if that is not intentional.
training_args = GRPOConfig(
    # Output and logging
    output_dir="outputs/Llama-1B-base-GRPO",
    run_name="Llama-1B-base-GRPO-gsm8k",
    logging_steps=1,
    report_to="wandb",
    log_on_each_node=False,
    # Optimiser and schedule
    learning_rate=1e-6,
    adam_beta1=0.9,
    adam_beta2=0.95,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    max_grad_norm=0.01,
    # Batching and generation
    per_device_train_batch_size=1,
    gradient_accumulation_steps=6,
    num_generations=12,
    max_completion_length=512,
)
# LoRA adapter settings.
# Currently unused: the GRPOTrainer call below does not pass it.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "down_proj", "gate_proj",
    ],
)
# Checkpoint to train; passed by name to GRPOTrainer below.
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the EOS token (presumably because the tokenizer defines no pad token — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
# Assemble the GRPO trainer and start training.
trainer = GRPOTrainer(
    model=model_name,
    processing_class=tokenizer,
    # Reward functions, in the order the trainer receives them.
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config  (disabled; re-enable to pass the LoRA config above)
)
trainer.train()
Has anyone been able to train anything larger than 3B? I still can't fine-tune large models without OOM, even with 8x H100.
Thanks willccbb!
When training on the GPU with qwen model, I encountered the error:
" probability tensor contains either inf, nan or element < 0"
I would appreciate any insights on a possible solution.
Trained on runpod, took about 90 minutes to do 250 steps
I was able to get it running using a similar setup as @qrdlgit above:
Trained on runpod, took about 90 minutes to do 250 steps
1x H100 NVL (94 GB VRAM)
94 GB RAM • 16 vCPU
Total Disk: 40 GB
pip install git+https://github.com/huggingface/trl.git accelerate transformers datasets peft wandb tqdm
Note that I had to pip install flash_attn as well.
But I had to install a specific version of flash-attn to get it to work:
!pip install flash-attn==2.3.6
thx @willccbb.
@ianand Could you please share your detailed configs? I tried many settings, but the reward stays at its initial level and doesn't change at all.
closely related project i'm working on to make RL with verifiers easier: https://github.com/willccbb/verifiers
currently the main focus is on supporting multi-step rollouts (tool use, multi-agent envs, games, code repls, etc)
to make an "environment" for running TRL GRPO (vLLM-only), all that's needed is to extend MultiStepEnv with methods that compute env responses + decide when a trajectory is finished (at the [{'role': 'user', 'content' : ...}] level, no tensor slicing required)
class MultiStepEnv(BaseEnv):
    """Base class for multi-turn environments driven by chat message lists.

    Subclasses implement ``is_completed`` (decide when a trajectory is finished)
    and ``env_response`` (produce the environment's next message).
    """

    def __init__(self,
                 system_prompt: str = "",
                 few_shot: List[Dict[str, str]] | None = None,
                 sampling_args: Dict[str, Any] | None = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.system_prompt = system_prompt
        # Create fresh containers per instance: a literal [] / {} default is
        # evaluated once and would be shared by every instance built without it.
        self.few_shot = [] if few_shot is None else few_shot
        self.sampling_args = {} if sampling_args is None else sampling_args

    @abstractmethod
    def is_completed(self, messages: List[Dict[str, str]], **kwargs: Any) -> bool:
        """Return True when the trajectory in ``messages`` is finished."""
        pass

    @abstractmethod
    def env_response(self, messages: List[Dict[str, str]], **kwargs: Any) -> Dict[str, str]:
        """Return the environment's next message given the conversation so far."""
        pass

# will be adding more env examples in the coming days/weeks
also have some basic support for encapsulating dataset/rubric construction inside envs
if you're a fan of grpo_demo.py, please consider checking it out :)
I think I've spotted an error in the formatting reward regexes. Someone correct me if I'm wrong!
E.g.
r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
Isn't working great because .*? is only selecting characters without newlines, and the reasoning normally includes newlines.
This directly contradicts the other xml reward which is incentivising newlines, i.e. \n</reasoning>\n
My suggestion is instead:
r"<reasoning>[\s\S]*</reasoning>\s*<answer>.*?</answer>"
i.e. the [\s\S]* selects all non-whitespace and whitespace characters. Alternatively can use re.match(pattern, r, flags=re.DOTALL).
@AlexChaloner yes good catch, fixed
为什么会出现越来越多的(\boxed{ANSWE})? 代码对于这个没有任何奖励,为什么训练过程中这个越来越多?
应该是Qwen的预训练数据有这个格式
After implementing some optimizations to the grpo trainer and tweaking params, I'm successfully running training of qwen2.5-0.5B-instruct on a free google colab T4 GPU, at ~13hours/epoch. There's hope for the GPU poor
Will be posting updates here.
@qunash can you share your parameters?
https://gist.github.com/qunash/820c86d1d267ec8051d9f68b4f4bb656
Somehow mine has near 0% accuracy when running on Qwen2.5 0.5B base (while others have reported 40+%).
I found that my model does not conform to the response XML format of <reasoning>...</reasoning><answer>...</answer>. I wonder what prompts you used to replicate the results?
My configs:
training:
  output_dir: "outputs/Qwen-0.5B-GRPO"
  run_name: "Qwen-0.5B-GRPO-gsm8k"
  learning_rate: 5.0e-6
  gradient_checkpointing: true
  # Evaluation settings
  do_eval: true
  eval_steps: 50
  per_device_eval_batch_size: 128  # Adjust based on your GPU memory
  eval_strategy: "steps"
  beta: 0.02
My prompt:
prompts:
  system_prompt: |
    Respond in the following format:
    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>
  xml_cot_format: |
    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>
@andrewsiah Hi,I encountered the same problem, the reward is always 0. Did you solve it?
After implementing some optimizations to the grpo trainer and tweaking params, I'm successfully running training of qwen2.5-0.5B-instruct on a free google colab T4 GPU, at ~13hours/epoch. There's hope for the GPU poor
Will be posting updates here.
@qunash can you share your parameters?
https://gist.github.com/qunash/820c86d1d267ec8051d9f68b4f4bb656
@qunash
Thx a lot for your sharing!
I have a question about your format reward pattern, it seems only matches a string which equals to \n ?

Hi, I wonder why the training loss starts from 0 during training. Is this related to the reward?
I have a question about your format reward pattern, it seems only matches a string which equals to \n ?
That's because <reasoning>, <answer> and any other tags are hidden when the notebook is rendered as markdown on github.
You need to open the notebook in Colab directly: https://colab.research.google.com/gist/qunash/820c86d1d267ec8051d9f68b4f4bb656/grpo_qwen-0-5b_single_t4.ipynb
I have a question about your format reward pattern, it seems only matches a string which equals to \n ?
That's because <reasoning>, <answer> and any other tags are hidden when the notebook is rendered as markdown on GitHub. You need to open the notebook in Colab directly: https://colab.research.google.com/gist/qunash/820c86d1d267ec8051d9f68b4f4bb656/grpo_qwen-0-5b_single_t4.ipynb
see, thx a lot for your help !
Has anyone been able to train anything larger than 3B? I still can't fine-tune large models without OOM, even with 8x H100.
有人训练过 30 亿以上参数的模型吗?我甚至无法在 H100x8 上微调大型模型,因为会内存不足
I got the same problem. I trained 7B with batch_size == 1, but it just keeps reporting OOM.
Did you solve the problem?
Has anyone been able to train anything larger than 3B? I still can't fine-tune large models without OOM, even with 8x H100.
有人训练过 30 亿以上参数的模型吗?我甚至无法在 H100x8 上微调大型模型,因为会内存不足
I got the same problem. I trained 7B with batch_size == 1, but it just keeps reporting OOM.
我遇到了同样的问题。我用 batch_size == 1 训练了 7B,但它一直报告 oom。
Did you solve the problem?
你解决这个问题了吗?
Me too! I also opened an issue about my problem on TRL's GitHub. If you solve the problem, please help me.
I used an A100(40 GB) on the colab. But it can barely support the training of a 1.5B model.
I also found that GRPO does not support parameter-efficient fine-tuning and can only fine-tune all parameters (see the linked issue). Can anyone solve this problem? Or are there other reinforcement learning frameworks that support parameter-efficient fine-tuning (PEFT)?
After about 3 hours of fine-tuning, the effect of the 1.5B model is indeed better than that of the 0.5B model.
 
a 1.5B model can be trained for nearly 460 steps in 3 hours with an a100
Can someone share code for evaluating the model on the GSM8K test set? Thanks a lot!
有人能提供一下在gsm8k数据的测试集上评估模型训练结果的代码吗?非常感谢!
How should I adjust the params with just two A10 GPUs?
@willccbb, why would we prefer separate reward functions instead of having a single unified one in GRPO?
being able to log individual rewards is pretty useful for debugging imo
consolidating them into one shouldn't affect actual training dynamics though
IMHO separate, additive rewards introduce a lot of repetition (e.g., parsing responses) and limit creativity in reward design, e.g., I may want the formatting reward to be a gate for the others, as I may not even want to evaluate a response if the formatting is wrong.
definitely get creative with it! nothing wrong with using if statements + multiplication in your reward functions
Does it also work on smaller models like >3B params model?
When training on the GPU with qwen model, I encountered the error: "probability tensor contains either inf, nan or element < 0"
Hi @fsxbhyy, did you load the model in torch.bfloat16? I used to encounter such issue when I loaded models in torch.float16 instead of bfloat. I guess float16 in this context leads to numerical instability, leading to NaN probs. Hope this helps!
I got the same problem. I trained 7B with batch_size == 1, but it just keep reporting oom.
@harrywoo @Tuziking I had the same problem. I then noticed that these values are actually huge for most cases:
max_prompt_length=256,
max_completion_length=786,
786 generated tokens to process per generation requires a lot of memory, especially if your group size is large. Try to set this to 150 or 250 and see if it reduces memory usage. Hope this helps!
After some tuning and training on the gsm8k train set (7.47k examples), the model after GRPO scores ~51% on the gsm8k test set (Qwen2.5-0.5B) vs. the base model's 41.6% as reported in their paper.
在 gsm8k 训练集(7.47k 示例)上进行了一些调整和训练后,GRPO 后的模型在 gsm8k 测试集上得分约为 51%(Qwen2.5-0.5B),而基础模型为 41.6%,如他们的论文所示。
Changes: 变化:
- Tune beta to 0.01 (将 beta 调整为 0.01)
- LR 2e-6
- num_generations = 8
- 0.07 warmup, cosine, no WD (0.07 预热,余弦,无 WD)
- 1x4x8 devices (32 total batch size) (1x4x8 设备,总共 32 个批量大小)
- max completion length 512 (最大完成长度 512)
- use only a system prompt (仅使用系统提示符)
- evaluated with greedy decoding on vLLM (在 vLLM 上使用贪婪解码进行评估)
Are you using Qwen-2.5-0.5B-Instruct as your base model? I noticed in Table 10 of the Qwen2.5-technical-report that Qwen-2.5-0.5B-Instruct scores 49.6 on GSM8K, and you mentioned your trained model achieved ~51%. From this perspective, it seems there wasn't a significant performance improvement. Please correct me if my assessment is wrong.
To be honest, I tried using Qwen2.5-1.5B-Instruct as the base model to train Qwen-1.5B-GRPO, and its performance on GSM8K was 73.24, which is almost identical to what's reported in the Qwen2.5 technical report. However, I did notice that the training brought format-related benefits. At the beginning of training, the model struggled to follow the output format required in the SYSTEM_PROMPT (<reasoning>...</reasoning><answer>...</answer>), but after training, the model could follow this format almost perfectly. This indicates that the training did bring certain benefits—but in my experiment, the improvements were primarily in formatting rather than solving ability. Do you have any insights on this?
Hey, has anyone been able to get the Llama model to work? I ask because I tried running the Llama model, but it would not format the answer correctly (which is an issue given how the reward functions are computed). I was able to get things to work better by running the 3B-parameter instruct model, but I was curious whether things should also work for the 1B-parameter model.
I, admittedly, only trained the 1B parameter model for ~30 steps (again all zero rewards for all 30 steps) before switching to the 3B parameter model.
Also nice work Will!
you guys are some serious prompt wizards 🪄🪄 and we need you sharing the knowledge at God Tier Prompts!






Do you mean adam_beta1 = 0.01 and adam_beta2 = 0.01? The default values are 0.9 and 0.999.