# train_grpo.py
from typing import Optional
from dataclasses import dataclass, field
import re

import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: "},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "Setting this to True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


def extract_xml_answer(text: str) -> str:
    """Pull out the text between the <answer> and </answer> tags."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    """GSM8K stores the gold answer after '####'; return it, or None if missing."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy of the dataset
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            # {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #     reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #     answer="7"
            # )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """2.0 if the extracted answer exactly matches the gold answer, else 0.0."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    """0.5 if the extracted answer is an integer string, else 0.0."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion matches the expected XML format exactly, newlines included."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely matches the expected XML format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    """Partial credit (0.125 per tag) for each expected XML tag, minus a small penalty for trailing text after </answer>."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


def main(model_args, training_args):
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # use peft at your own risk; not working for me with multi-GPU training
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config,
    )
    trainer.train()


if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)
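# Example launch, as a sketch only: the model name, output directory, and the GRPO
# hyperparameter values below are illustrative placeholders, not tested settings.
#
#   accelerate launch train_grpo.py \
#       --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
#       --output_dir outputs/grpo-gsm8k \
#       --per_device_train_batch_size 1 \
#       --gradient_accumulation_steps 4 \
#       --num_generations 8 \
#       --max_prompt_length 256 \
#       --max_completion_length 256 \
#       --learning_rate 1e-6 \
#       --bf16 \
#       --logging_steps 1
#
# The same ModelArguments/GRPOConfig fields can instead be supplied in a YAML file
# passed to TrlParser via --config.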