@HarleyCoops
Forked from willccbb/grpo_demo.py
Created January 29, 2025 00:15

Revisions

  1. @willccbb revised this gist Jan 27, 2025. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions grpo_demo.py
    @@ -125,8 +125,8 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
         task_type="CAUSAL_LM",
         lora_dropout=0.05,
     )
    -model_name = "meta-llama/Llama-3.2-1B"
    -tokenizer = AutoTokenizer.from_pretrained(model_name + "-Instruct")
    +model_name = "meta-llama/Llama-3.2-1B-Instruct"
    +tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
     trainer = GRPOTrainer(
         model=model_name,
  2. @willccbb revised this gist Jan 27, 2025. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion grpo_demo.py
    @@ -61,7 +61,6 @@ def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
         q = prompts[0][-1]['content']
         extracted_responses = [extract_xml_answer(r) for r in responses]
         print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    -    responses = [extract_xml_answer(r) for r in responses]
         return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

     def int_reward_func(completions, **kwargs) -> list[float]:
  3. @willccbb revised this gist Jan 27, 2025. 1 changed file with 28 additions and 7 deletions.
    35 changes: 28 additions & 7 deletions grpo_demo.py
    @@ -8,7 +8,7 @@
     # Load and prep dataset

     SYSTEM_PROMPT = """
    -Respond the in the following format:
    +Respond in the following format:
     <reasoning>
     ...
     </reasoning>
    @@ -18,6 +18,15 @@
     </answer>
     """

    +XML_COT_FORMAT = """\
    +<reasoning>
    +{reasoning}
    +</reasoning>
    +<answer>
    +{answer}
    +</answer>
    +"""
    +
     def extract_xml_answer(text: str) -> str:
         answer = text.split("<answer>")[-1]
         answer = answer.split("</answer>")[0]
    @@ -33,6 +42,11 @@ def get_gsm8k_questions(split = "train") -> Dataset:
         data = data.map(lambda x: { # type: ignore
             'prompt': [
                 {'role': 'system', 'content': SYSTEM_PROMPT},
    +            {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
    +            {'role': 'assistant', 'content': XML_COT_FORMAT.format(
    +                reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
    +                answer="7"
    +            )},
                 {'role': 'user', 'content': x['question']}
             ],
             'answer': extract_hash_answer(x['answer'])
    @@ -77,24 +91,31 @@ def count_xml(text) -> float:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
    +        count -= len(text.split("\n</answer>\n")[-1])*0.001
         if text.count("\n</answer>") == 1:
             count += 0.125
    +        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
         return count

     def xmlcount_reward_func(completions, **kwargs) -> list[float]:
         contents = [completion[0]["content"] for completion in completions]
         return [count_xml(c) for c in contents]

     training_args = GRPOConfig(
    -    output_dir="outputs/Llama-1B-GRPO",
    -    run_name="Llama-1B-GRPO-gsm8k",
    -    learning_rate=3e-6,
    +    output_dir="outputs/Llama-1B-base-GRPO",
    +    run_name="Llama-1B-base-GRPO-gsm8k",
    +    learning_rate=1e-6,
    +    adam_beta1 = 0.9,
    +    adam_beta2 = 0.95,
    +    weight_decay = 0.1,
    +    warmup_ratio = 0.1,
    +    lr_scheduler_type='cosine',
         logging_steps=1,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=6,
         num_generations=12,
         max_completion_length=512,
    -    max_grad_norm=0.001,
    +    max_grad_norm=0.01,
         report_to="wandb",
         log_on_each_node=False,
     )
    @@ -105,8 +126,8 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
         task_type="CAUSAL_LM",
         lora_dropout=0.05,
     )
    -model_name = "meta-llama/Llama-3.2-1B-Instruct"
    -tokenizer = AutoTokenizer.from_pretrained(model_name)
    +model_name = "meta-llama/Llama-3.2-1B"
    +tokenizer = AutoTokenizer.from_pretrained(model_name + "-Instruct")
     tokenizer.pad_token = tokenizer.eos_token
     trainer = GRPOTrainer(
         model=model_name,
  4. @willccbb revised this gist Jan 26, 2025. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion grpo_demo.py
    @@ -1,4 +1,4 @@
    -# grpo_demo.py
    +# train_grpo.py
     import re
     from datasets import load_dataset, Dataset
     from transformers import AutoTokenizer
    @@ -94,6 +94,7 @@ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
         gradient_accumulation_steps=6,
         num_generations=12,
         max_completion_length=512,
    +    max_grad_norm=0.001,
         report_to="wandb",
         log_on_each_node=False,
     )
  5. @willccbb created this gist Jan 26, 2025.
    123 changes: 123 additions & 0 deletions grpo_demo.py
    @@ -0,0 +1,123 @@
    # grpo_demo.py
    import re
    from datasets import load_dataset, Dataset
    from transformers import AutoTokenizer
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    # Load and prep dataset

    SYSTEM_PROMPT = """
    Respond the in the following format:
    <reasoning>
    ...
    </reasoning>
    <answer>
    ...
    </answer>
    """

    def extract_xml_answer(text: str) -> str:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    def extract_hash_answer(text: str) -> str | None:
        if "####" not in text:
            return None
        return text.split("####")[1].strip()

    def get_gsm8k_questions(split = "train") -> Dataset:
        data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
        data = data.map(lambda x: { # type: ignore
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']}
            ],
            'answer': extract_hash_answer(x['answer'])
        }) # type: ignore
        return data # type: ignore

    dataset = get_gsm8k_questions()

    # Reward functions
    def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
        responses = [completion[0]['content'] for completion in completions]
        q = prompts[0][-1]['content']
        extracted_responses = [extract_xml_answer(r) for r in responses]
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
        responses = [extract_xml_answer(r) for r in responses]
        return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

    def int_reward_func(completions, **kwargs) -> list[float]:
        responses = [completion[0]['content'] for completion in completions]
        extracted_responses = [extract_xml_answer(r) for r in responses]
        return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

    def strict_format_reward_func(completions, **kwargs) -> list[float]:
        """Reward function that checks if the completion has a specific format."""
        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
        responses = [completion[0]["content"] for completion in completions]
        matches = [re.match(pattern, r) for r in responses]
        return [0.5 if match else 0.0 for match in matches]

    def soft_format_reward_func(completions, **kwargs) -> list[float]:
        """Reward function that checks if the completion has a specific format."""
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        responses = [completion[0]["content"] for completion in completions]
        matches = [re.match(pattern, r) for r in responses]
        return [0.5 if match else 0.0 for match in matches]

    def count_xml(text) -> float:
        count = 0.0
        if text.count("<reasoning>\n") == 1:
            count += 0.125
        if text.count("\n</reasoning>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
        if text.count("\n</answer>") == 1:
            count += 0.125
        return count

    def xmlcount_reward_func(completions, **kwargs) -> list[float]:
        contents = [completion[0]["content"] for completion in completions]
        return [count_xml(c) for c in contents]

    training_args = GRPOConfig(
        output_dir="outputs/Llama-1B-GRPO",
        run_name="Llama-1B-GRPO-gsm8k",
        learning_rate=3e-6,
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=6,
        num_generations=12,
        max_completion_length=512,
        report_to="wandb",
        log_on_each_node=False,
    )
    peft_config = LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    trainer = GRPOTrainer(
        model=model_name,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func],
        args=training_args,
        train_dataset=dataset,
        #peft_config=peft_config
    )
    trainer.train()
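
For orientation, here is a small standalone sketch (not part of the gist) of how three of the reward functions above score a completion. It assumes the nested message format that GRPOTrainer passes to reward_funcs, re-copies extract_xml_answer so the check runs without importing the training script (importing grpo_demo.py would kick off trainer.train()), and uses a made-up completion and reference answer.

    # Hypothetical sanity check of the reward shaping in grpo_demo.py (not from the gist).
    import re

    def extract_xml_answer(text: str) -> str:
        # Copied from grpo_demo.py: pull out the text between <answer> tags.
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    # A hand-written completion in the shape the gist's reward functions expect:
    # one list per prompt, each holding a single assistant message dict.
    completions = [[{"role": "assistant", "content": (
        "<reasoning>\n18 - 3 - 4 = 11, and 11 * 2 = 22.\n</reasoning>\n"
        "<answer>\n22\n</answer>\n"
    )}]]
    answer = ["22"]  # made-up ground-truth answers, one per prompt

    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]

    # correctness_reward_func: 2.0 for an exact string match with the reference answer.
    print([2.0 if r == a else 0.0 for r, a in zip(extracted, answer)])  # [2.0]

    # int_reward_func: 0.5 if the extracted answer is a bare integer.
    print([0.5 if r.isdigit() else 0.0 for r in extracted])             # [0.5]

    # strict_format_reward_func: 0.5 if the whole completion matches the XML template.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    print([0.5 if re.match(pattern, r) else 0.0 for r in responses])    # [0.5]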