Gist cgpeter96/53ffcd5b49c10e8de5303059c21388ac, created by cgpeter96 on February 7, 2025 13:53.
# train_grpo.py
from typing import Optional

import re
import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune,
    or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: "},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the "
                "pretrained weights are loaded. Setting this to True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    # GSM8K reference answers end with "#### <final answer>"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy of the dataset
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            # {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #     reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #     answer="7"
            # )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    # Partial credit for each well-placed tag; small penalty for trailing text after </answer>.
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
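# Illustrative sanity check (added; not in the original gist): GRPOTrainer calls each
# reward function with `completions` as a list of chat-message lists and forwards
# `prompts` plus extra dataset columns (here `answer`) as keyword arguments, so the
# functions above can be exercised offline, e.g.:
#
#   fake_completion = [[{"role": "assistant",
#                        "content": "<reasoning>\n3 + 4 = 7\n</reasoning>\n<answer>\n7\n</answer>\n"}]]
#   correctness_reward_func(prompts=[[{"role": "user", "content": "What is 3 + 4?"}]],
#                           completions=fake_completion, answer=["7"])  # -> [2.0]
#   strict_format_reward_func(completions=fake_completion)              # -> [0.5]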
def main(model_args, training_args):
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # use peft at your own risk; not working for me with multi-GPU training
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config
    )
    trainer.train()


if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)


# Launch script (second file in the gist). The training script above is invoked here as
# grpo_demo.py; adjust the filename if you saved it differently.
CUDA_HOME=/usr/local/cuda
gpus=${gpus:-0,1,2,3,4,5,6,7}
# dataset and output_dir defaults are defined but not referenced by the deepspeed command below
dataset=${dataset:-sft_train_data}
output_dir=${output_dir:-experiments/$(date +"%Y%m%d_%H%M%S")}
port=$(shuf -i 10000-20000 -n 1)

deepspeed --include localhost:${gpus} --master_port=$port grpo_demo.py \
    --deepspeed "ds_zero2.json" \
    --model_name_or_path "path to Qwen2.5-1.5B-Instruct/" \
    --output_dir outputs/Qwen2.5-1.5B-GRPO-gsm8k \
    --run_name Qwen2.5-1.5B-GRPO-gsm8k \
    --learning_rate 1e-5 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.1 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --bf16 True \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --num_generations 16 \
    --max_prompt_length 512 \
    --max_completion_length 768 \
    --num_train_epochs 5 \
    --save_steps 100 \
    --max_grad_norm 0.1 \
    --report_to tensorboard \
    --log_on_each_node False
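The ds_zero2.json referenced by the launch command is not included in the gist. As a rough sketch (an assumption, not the author's actual config), a minimal ZeRO stage-2 file that works with the Hugging Face Trainer/DeepSpeed integration could be created like this, with the "auto" values filled in from the training arguments:

# Sketch only: the author's real ds_zero2.json may differ.
cat > ds_zero2.json <<'EOF'
{
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "bf16": { "enabled": "auto" },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}
EOF

The launch script's defaults can be overridden through the environment, e.g. setting gpus=0,1 before running it to train on two GPUs.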