Last active: February 25, 2025 22:52
Revisions
abacaj revised this gist
Feb 5, 2025. 1 changed file with 1 addition and 1 deletion.
```diff
@@ -75,7 +75,7 @@ def tokenize_validation(tokenizer, samples, max_prompt_length):
         ids = tokenizer.apply_chat_template(
             prompt,
             add_generation_prompt=True,
-            truncation=True,
+            truncation=False,
             max_length=max_prompt_length,
         )
         tokenized_samples.append((ids, answer))
```
abacaj revised this gist
Feb 5, 2025. 1 changed file with 1 addition and 1 deletion.
```diff
@@ -1,5 +1,5 @@
 import tqdm
-import numpy as npy
+import numpy as np
 import torch
 import torch.distributed as dist
 import transformers
```
abacaj revised this gist
Feb 5, 2025. 1 changed file with 2 additions and 2 deletions.
```diff
@@ -81,7 +81,7 @@ def tokenize_validation(tokenizer, samples, max_prompt_length):
         tokenized_samples.append((ids, answer))
     return tokenized_samples
 
-class CustomTrainer(GRPOTrainer):
+class CustomTrainer(transformers.GRPOTrainer):
     def evaluate(
         self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
     ):
@@ -100,7 +100,7 @@ def evaluate(
 
         return output
 
-training_args = GRPOConfig(
+training_args = transformers.GRPOConfig(
     output_dir=f"checkpoints/qwen25-05b",
     bf16=True,
     max_prompt_length=356,
```
abacaj revised this gist
Feb 5, 2025. 1 changed file with 3 additions and 0 deletions.
```diff
@@ -113,6 +113,9 @@ def evaluate(
     eval_strategy="steps"
 )
 
+dataset = get_gsm8k_questions()
+test_dataset = get_gsm8k_questions("test")
+
 trainer = CustomTrainer(
     model=model,
     processing_class=tokenizer,
```
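`get_gsm8k_questions` is referenced by this revision (and by the trainer setup below) but is never defined in the gist. A minimal sketch of what it could look like, assuming the Hugging Face `datasets` GSM8K data (whose reference answers end in `#### <number>`) and a system prompt that asks for `<final_answer>` tags so completions match what `extract_xml_answer` parses; `SYSTEM_PROMPT` and `extract_hash_answer` are hypothetical names, not part of the original code:

```python
# Hypothetical sketch -- get_gsm8k_questions is not included in the gist.
from datasets import load_dataset

SYSTEM_PROMPT = (
    "Think step by step, then give only the final numeric answer inside "
    "<final_answer></final_answer> tags."
)

def extract_hash_answer(text: str) -> str | None:
    # GSM8K reference answers end with "#### <number>"
    if "####" not in text:
        return None
    # strip thousands separators so string comparison with the model output works
    return text.split("####")[1].strip().replace(",", "")

def get_gsm8k_questions(split: str = "train"):
    data = load_dataset("openai/gsm8k", "main")[split]
    # build chat-style prompts plus a clean ground-truth answer column
    return data.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    })
```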
abacaj created this gist
Feb 5, 2025.
```python
import tqdm
import numpy as npy
import torch
import torch.distributed as dist
import transformers


def extract_xml_answer(text: str) -> str:
    answer = text.split("<final_answer>")[-1]
    answer = answer.split("</final_answer>")[0]
    return answer.strip()


def generate_gsm8k(
    model,
    tokenizer,
    tokenized_samples,
    batch_size,
    max_completion_length
):
    # run eval on main
    if dist.get_rank() == 0:
        device = model.device
        predictions = []
        generation_config = transformers.GenerationConfig(
            max_new_tokens=max_completion_length,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        model.eval()
        count = len(tokenized_samples)
        status = tqdm.tqdm(tokenized_samples, desc=f"Correct: 0/{count}")

        for i in range(0, count, batch_size):
            batches = tokenized_samples[i:i+batch_size]
            with torch.inference_mode():
                longest = max(len(b[0]) for b in batches)

                # pad to longest on left side for decoder
                padded_input_ids = torch.stack([
                    torch.tensor([tokenizer.pad_token_id] * (longest - len(ids)) + ids)
                    for ids, _ in batches
                ]).to(device)

                # ignore pad token when generating
                attn_mask = torch.stack([
                    tokens.ne(tokenizer.pad_token_id) for tokens in padded_input_ids
                ]).to(device)

                output = model.generate(
                    input_ids=padded_input_ids,
                    attention_mask=attn_mask,
                    generation_config=generation_config,
                )

                for i, generated in enumerate(output):
                    response = tokenizer.decode(
                        generated[len(padded_input_ids[i]) :], skip_special_tokens=True
                    )
                    prediction = extract_xml_answer(response)
                    predictions.append(batches[i][1] == prediction)

                status.update(batch_size)
                status.set_description(f"Correct: {sum(predictions)}/{count}")

        return np.mean(predictions)
    return 0


def tokenize_validation(tokenizer, samples, max_prompt_length):
    tokenized_samples = []
    for sample in samples:
        prompt = sample["prompt"]
        answer = sample['answer']
        ids = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            truncation=True,
            max_length=max_prompt_length,
        )
        tokenized_samples.append((ids, answer))
    return tokenized_samples


class CustomTrainer(GRPOTrainer):
    def evaluate(
        self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
    ):
        tokenized_samples = tokenize_validation(self.processing_class, self.eval_dataset, self.args.max_prompt_length)
        eval_acc = generate_gsm8k(self.model, self.processing_class, tokenized_samples, self.args.per_device_eval_batch_size, self.args.max_completion_length)

        output = {
            f"{metric_key_prefix}_accuracy": eval_acc,
            "epoch": self.state.epoch,
        }

        self.log(output)
        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output
        )

        return output


training_args = GRPOConfig(
    output_dir=f"checkpoints/qwen25-05b",
    bf16=True,
    max_prompt_length=356,
    max_completion_length=512,
    learning_rate=learning_rate,
    ... rest of config,
    eval_steps=20,
    per_device_eval_batch_size=256, # adjust based on your GPU! may cause oom error
    do_eval=True,
    eval_strategy="steps"
)

trainer = CustomTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    eval_dataset=test_dataset,
)
```
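The reward functions passed to `CustomTrainer` (`strict_format_reward_func`, `int_reward_func`, `correctness_reward_func`), like `model`, `tokenizer`, and `learning_rate`, are not part of this gist. A hedged sketch of what they might look like, assuming the trl-style GRPO convention in which each reward function receives the prompts, completions, and any extra dataset columns (here `answer`) as keyword arguments and returns one float per completion; the regex, reward magnitudes, and the `_responses` helper are assumptions for illustration:

```python
# Hypothetical sketches -- these reward functions are not included in the gist.
import re

def _responses(completions):
    # conversational format: each completion is a list of chat messages
    return [completion[0]["content"] for completion in completions]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    # reward completions that wrap the answer in <final_answer> tags
    pattern = r"<final_answer>.*?</final_answer>"
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in _responses(completions)]

def int_reward_func(completions, **kwargs) -> list[float]:
    # reward extracted answers that are plain integers
    # extract_xml_answer is defined in the gist above
    extracted = [extract_xml_answer(r) for r in _responses(completions)]
    return [0.5 if e.lstrip("-").isdigit() else 0.0 for e in extracted]

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    # main reward: the extracted answer matches the GSM8K ground truth
    extracted = [extract_xml_answer(r) for r in _responses(completions)]
    return [2.0 if e == a else 0.0 for e, a in zip(extracted, answer)]
```

Because `generate_gsm8k` calls `dist.get_rank()`, the script presumably expects to be launched with a distributed runner (for example `accelerate launch` or `torchrun`) so that the default process group is initialized before evaluation runs.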