Gist by @cgpeter96, created February 7, 2025 13:53
grpo_demo.py
# train_grpo.py
from typing import *
import re
import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser
from dataclasses import dataclass, field

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: "},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Create the model as an empty shell and only materialize its parameters when the pretrained weights are loaded. "
                "Setting this to True reduces LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
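
# Worked example (added for illustration, not from the original gist): GSM8K gold answers
# end with "#### <final answer>", so for a hypothetical string
#   s = "He pays 10 * 2 = 20 dollars. #### 20"
# extract_hash_answer(s) returns "20". extract_xml_answer is the matching helper for model
# completions written in the <reasoning>/<answer> format defined above.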

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy of the dataset
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            # {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #     reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #     answer="7"
            # )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion exactly follows the newline-delimited XML format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely follows the XML format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
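
# Sanity check (added for illustration): a completion that exactly matches XML_COT_FORMAT
# scores 4 * 0.125 = 0.5, and any text trailing the closing </answer> tag is penalized at
# 0.001 per character, nudging the model not to ramble after the answer.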

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
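
# Note (added for clarity, not in the original gist): GRPOTrainer combines the reward
# functions passed below by summing their per-completion outputs, so a well-formatted,
# correct integer answer earns roughly
# 0.5 (xmlcount) + 0.5 (soft) + 0.5 (strict) + 0.5 (int) + 2.0 (correctness) = 4.0.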

def main(model_args, training_args):
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # Many chat models ship without a dedicated pad token; reuse EOS so batching works.
    tokenizer.pad_token = tokenizer.eos_token

    # use peft at your own risk; not working for me with multi-GPU training
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config,
    )
    trainer.train()

if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)
run_grpo_deepspeed.sh
CUDA_HOME=/usr/local/cuda
gpus=${gpus:-0,1,2,3,4,5,6,7}
dataset=${dataset:-sft_train_data}
output_dir=${output_dir:-experiments/$(date +"%Y%m%d_%H%M%S")}

port=$(shuf -i 10000-20000 -n 1)
deepspeed --include localhost:${gpus} --master_port=$port grpo_demo.py \
    --deepspeed "ds_zero2.json" \
    --model_name_or_path "path to Qwen2.5-1.5B-Instruct/" \
    --output_dir outputs/Qwen2.5-1.5B-GRPO-gsm8k \
    --run_name Qwen2.5-1.5B-GRPO-gsm8k \
    --learning_rate 1e-5 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.1 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --bf16 True \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --num_generations 16 \
    --max_prompt_length 512 \
    --max_completion_length 768 \
    --num_train_epochs 5 \
    --save_steps 100 \
    --max_grad_norm 0.1 \
    --report_to tensorboard \
    --log_on_each_node False
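
# Usage sketch (added for illustration, not part of the original gist): the settings above are
# environment variables with defaults, so a two-GPU debug run could look like
#   gpus=0,1 output_dir=experiments/debug bash run_grpo_deepspeed.sh
#
# The referenced ds_zero2.json is not included in the gist. A minimal ZeRO-2 config using the
# standard DeepSpeed/Transformers "auto" fields might look like the following (hypothetical file):
#   {
#     "zero_optimization": { "stage": 2, "overlap_comm": true },
#     "bf16": { "enabled": "auto" },
#     "train_micro_batch_size_per_gpu": "auto",
#     "gradient_accumulation_steps": "auto",
#     "gradient_clipping": "auto"
#   }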