Skip to content

Instantly share code, notes, and snippets.

@XAheli
Forked from willccbb/grpo_demo.py
Created February 1, 2025 09:38
Show Gist options
  • Save XAheli/50b42dc1ae9dbee7ad02a68b2282f109 to your computer and use it in GitHub Desktop.
Save XAheli/50b42dc1ae9dbee7ad02a68b2282f109 to your computer and use it in GitHub Desktop.

Revisions

  1. @willccbb willccbb revised this gist Jan 29, 2025. 1 changed file with 41 additions and 17 deletions.
    58 changes: 41 additions & 17 deletions grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -1,7 +1,8 @@
    # train_grpo.py
    import re
    import torch
    from datasets import load_dataset, Dataset
    from transformers import AutoTokenizer
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    @@ -37,16 +38,17 @@ def extract_hash_answer(text: str) -> str | None:
    return None
    return text.split("####")[1].strip()

    # uncomment middle messages for 1-shot prompting
    def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
    'prompt': [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
    {'role': 'assistant', 'content': XML_COT_FORMAT.format(
    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
    answer="7"
    )},
    #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
    #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
    # reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
    # answer="7"
    #)},
    {'role': 'user', 'content': x['question']}
    ],
    'answer': extract_hash_answer(x['answer'])
    @@ -81,7 +83,7 @@ def soft_format_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

    def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
    @@ -100,21 +102,35 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

    #model_name = "meta-llama/Llama-3.2-1B-Instruct"
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"

    if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
    else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

    training_args = GRPOConfig(
    output_dir="outputs/Llama-1B-base-GRPO",
    run_name="Llama-1B-base-GRPO-gsm8k",
    learning_rate=1e-6,
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.95,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=6,
    num_generations=12,
    max_completion_length=512,
    max_grad_norm=0.01,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
    )
    @@ -125,11 +141,19 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    )
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # use peft at your own risk; not working for me with multi-GPU training
    trainer = GRPOTrainer(
    model=model_name,
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
    xmlcount_reward_func,
  2. @willccbb willccbb revised this gist Jan 27, 2025. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -125,8 +125,8 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    )
    model_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_name + "-Instruct")
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    trainer = GRPOTrainer(
    model=model_name,
  3. @willccbb willccbb revised this gist Jan 27, 2025. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -61,7 +61,6 @@ def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[floa
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

    def int_reward_func(completions, **kwargs) -> list[float]:
  4. @willccbb willccbb revised this gist Jan 27, 2025. 1 changed file with 28 additions and 7 deletions.
    35 changes: 28 additions & 7 deletions grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -8,7 +8,7 @@
    # Load and prep dataset

    SYSTEM_PROMPT = """
    Respond the in the following format:
    Respond in the following format:
    <reasoning>
    ...
    @@ -18,6 +18,15 @@
    </answer>
    """

    XML_COT_FORMAT = """\
    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>
    """

    def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    @@ -33,6 +42,11 @@ def get_gsm8k_questions(split = "train") -> Dataset:
    data = data.map(lambda x: { # type: ignore
    'prompt': [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
    {'role': 'assistant', 'content': XML_COT_FORMAT.format(
    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
    answer="7"
    )},
    {'role': 'user', 'content': x['question']}
    ],
    'answer': extract_hash_answer(x['answer'])
    @@ -77,24 +91,31 @@ def count_xml(text) -> float:
    count += 0.125
    if text.count("\n<answer>\n") == 1:
    count += 0.125
    count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
    count += 0.125
    count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

    def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

    training_args = GRPOConfig(
    output_dir="outputs/Llama-1B-GRPO",
    run_name="Llama-1B-GRPO-gsm8k",
    learning_rate=3e-6,
    output_dir="outputs/Llama-1B-base-GRPO",
    run_name="Llama-1B-base-GRPO-gsm8k",
    learning_rate=1e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.95,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=6,
    num_generations=12,
    max_completion_length=512,
    max_grad_norm=0.001,
    max_grad_norm=0.01,
    report_to="wandb",
    log_on_each_node=False,
    )
    @@ -105,8 +126,8 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    )
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_name + "-Instruct")
    tokenizer.pad_token = tokenizer.eos_token
    trainer = GRPOTrainer(
    model=model_name,
  5. @willccbb willccbb revised this gist Jan 26, 2025. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,4 @@
    # grpo_demo.py
    # train_grpo.py
    import re
    from datasets import load_dataset, Dataset
    from transformers import AutoTokenizer
    @@ -94,6 +94,7 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    gradient_accumulation_steps=6,
    num_generations=12,
    max_completion_length=512,
    max_grad_norm=0.001,
    report_to="wandb",
    log_on_each_node=False,
    )
  6. @willccbb willccbb created this gist Jan 26, 2025.
    123 changes: 123 additions & 0 deletions grpo_demo.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,123 @@
    # grpo_demo.py
    import re
    from datasets import load_dataset, Dataset
    from transformers import AutoTokenizer
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    # Load and prep dataset

    SYSTEM_PROMPT = """
    Respond the in the following format:
    <reasoning>
    ...
    </reasoning>
    <answer>
    ...
    </answer>
    """

    def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

    def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
    return None
    return text.split("####")[1].strip()

    def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
    'prompt': [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': x['question']}
    ],
    'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

    dataset = get_gsm8k_questions()

    # Reward functions
    def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

    def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

    def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

    def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

    def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
    count += 0.125
    if text.count("\n</reasoning>\n") == 1:
    count += 0.125
    if text.count("\n<answer>\n") == 1:
    count += 0.125
    if text.count("\n</answer>") == 1:
    count += 0.125
    return count

    def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

    training_args = GRPOConfig(
    output_dir="outputs/Llama-1B-GRPO",
    run_name="Llama-1B-GRPO-gsm8k",
    learning_rate=3e-6,
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=6,
    num_generations=12,
    max_completion_length=512,
    report_to="wandb",
    log_on_each_node=False,
    )
    peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    )
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    trainer = GRPOTrainer(
    model=model_name,
    processing_class=tokenizer,
    reward_funcs=[
    xmlcount_reward_func,
    soft_format_reward_func,
    strict_format_reward_func,
    int_reward_func,
    correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
    )
    trainer.train()