Skip to content

Instantly share code, notes, and snippets.

@abacaj
Last active February 25, 2025 22:52
Show Gist options
  • Save abacaj/9a567910c1a8663f7aa04520075e0ba8 to your computer and use it in GitHub Desktop.
Save abacaj/9a567910c1a8663f7aa04520075e0ba8 to your computer and use it in GitHub Desktop.

Revisions

  1. abacaj revised this gist Feb 5, 2025. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion train.py
    Original file line number Diff line number Diff line change
    @@ -75,7 +75,7 @@ def tokenize_validation(tokenizer, samples, max_prompt_length):
    ids = tokenizer.apply_chat_template(
    prompt,
    add_generation_prompt=True,
    - truncation=True,
    + truncation=False,
    max_length=max_prompt_length,
    )
    tokenized_samples.append((ids, answer))
  2. abacaj revised this gist Feb 5, 2025. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion train.py
    Original file line number Diff line number Diff line change
    @@ -1,5 +1,5 @@
    import tqdm
    - import numpy as npy
    + import numpy as np
    import torch
    import torch.distributed as dist
    import transformers
  3. abacaj revised this gist Feb 5, 2025. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -81,7 +81,7 @@ def tokenize_validation(tokenizer, samples, max_prompt_length):
    tokenized_samples.append((ids, answer))
    return tokenized_samples

    - class CustomTrainer(GRPOTrainer):
    + class CustomTrainer(transformers.GRPOTrainer):
    def evaluate(
    self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
    ):
    @@ -100,7 +100,7 @@ def evaluate(

    return output

    - training_args = GRPOConfig(
    + training_args = transformers.GRPOConfig(
    output_dir=f"checkpoints/qwen25-05b",
    bf16=True,
    max_prompt_length=356,
  4. abacaj revised this gist Feb 5, 2025. 1 changed file with 3 additions and 0 deletions.
    3 changes: 3 additions & 0 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -113,6 +113,9 @@ def evaluate(
    eval_strategy="steps"
    )

    + dataset = get_gsm8k_questions()
    + test_dataset = get_gsm8k_questions("test")
    +
    trainer = CustomTrainer(
    model=model,
    processing_class=tokenizer,
  5. abacaj created this gist Feb 5, 2025.
    127 changes: 127 additions & 0 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,127 @@
    import tqdm
    import numpy as npy
    import torch
    import torch.distributed as dist
    import transformers

    def extract_xml_answer(text: str) -> str:
        """Return the text enclosed by the last ``<final_answer>`` tag pair.

        Everything up to and including the last ``<final_answer>`` marker is
        dropped, then everything from the first ``</final_answer>`` that
        follows is dropped, and the remainder is stripped of surrounding
        whitespace. A missing marker leaves that side of the text untouched,
        so untagged input comes back stripped but otherwise unchanged.
        """
        # rpartition keeps what follows the LAST opening tag (or all of `text`).
        _, _, after_open = text.rpartition("<final_answer>")
        # partition keeps what precedes the FIRST closing tag (or everything).
        inner, _, _ = after_open.partition("</final_answer>")
        return inner.strip()

    def generate_gsm8k(
        model,
        tokenizer,
        tokenized_samples,
        batch_size,
        max_completion_length
    ):
        """Greedy-decode every validation prompt and return exact-match accuracy.

        Args:
            model: Model exposing ``generate``, ``eval``/``train``, ``training``
                and ``device``.
            tokenizer: Tokenizer supplying ``pad_token_id``, ``eos_token_id``
                and ``decode``.
            tokenized_samples: Sequence of ``(prompt_token_ids, answer)`` pairs.
            batch_size: Number of prompts sent to one ``model.generate`` call.
            max_completion_length: Cap on newly generated tokens per prompt.

        Returns:
            The mean of the per-sample matches, where a match means the answer
            extracted from the completion equals the reference answer. Ranks
            other than 0 skip generation and return ``0``; an empty sample
            list yields ``0.0``.
        """
        # Local import: the mean below needs `np` regardless of the alias the
        # file-level numpy import happens to use.
        import numpy as np

        # run eval on main; a run without an initialised process group is
        # treated as main instead of failing inside dist.get_rank()
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return 0

        count = len(tokenized_samples)
        if count == 0:
            # Nothing to score; avoid taking the mean of an empty list (NaN).
            return 0.0

        device = model.device
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # Padded positions are masked out below, so any valid id will do.
            pad_id = tokenizer.eos_token_id

        generation_config = transformers.GenerationConfig(
            max_new_tokens=max_completion_length,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )

        predictions = []
        was_training = model.training
        model.eval()
        try:
            with tqdm.tqdm(total=count, desc=f"Correct: 0/{count}") as status:
                for start in range(0, count, batch_size):
                    batch = tokenized_samples[start:start + batch_size]
                    longest = max(len(ids) for ids, _ in batch)

                    # pad to longest on left side for decoder; the mask is
                    # built from the padding lengths so prompt tokens that
                    # happen to equal the pad id are still attended to
                    rows, mask_rows = [], []
                    for ids, _ in batch:
                        pad = longest - len(ids)
                        rows.append([pad_id] * pad + list(ids))
                        mask_rows.append([0] * pad + [1] * len(ids))

                    with torch.inference_mode():
                        padded_input_ids = torch.tensor(rows, device=device)
                        attn_mask = torch.tensor(mask_rows, device=device)
                        output = model.generate(
                            input_ids=padded_input_ids,
                            attention_mask=attn_mask,
                            generation_config=generation_config,
                        )

                    for (_, answer), generated in zip(batch, output):
                        # Every row shares the padded prompt length, so the
                        # completion starts at index `longest`.
                        response = tokenizer.decode(
                            generated[longest:], skip_special_tokens=True
                        )
                        prediction = extract_xml_answer(response)
                        predictions.append(answer == prediction)

                    # The last batch may be shorter than batch_size.
                    status.update(len(batch))
                    status.set_description(f"Correct: {sum(predictions)}/{count}")
        finally:
            # Hand the model back in the mode it was received in.
            model.train(was_training)

        return np.mean(predictions)

    def tokenize_validation(tokenizer, samples, max_prompt_length):
        """Turn validation samples into ``(prompt_token_ids, answer)`` pairs.

        Args:
            tokenizer: Object providing ``apply_chat_template``.
            samples: Iterable of mappings with a ``"prompt"`` entry (passed
                to the chat template) and an ``"answer"`` entry.
            max_prompt_length: Forwarded to the tokenizer as ``max_length``.

        Returns:
            A list with one ``(ids, answer)`` tuple per sample, in input order.
        """
        tokenized_samples = []
        for sample in samples:
            prompt = sample["prompt"]
            answer = sample["answer"]
            ids = tokenizer.apply_chat_template(
                prompt,
                add_generation_prompt=True,
                # Truncating would cut the prompt from the end, dropping the
                # generation prompt that add_generation_prompt appends, so
                # validation prompts are kept whole.
                truncation=False,
                # NOTE(review): presumably has no effect while truncation is
                # off; kept so the call still carries the configured limit.
                max_length=max_prompt_length,
            )
            tokenized_samples.append((ids, answer))
        return tokenized_samples

    class CustomTrainer(GRPOTrainer):
        """GRPO trainer whose ``evaluate`` reports generation accuracy.

        ``GRPOTrainer`` is not imported in the visible part of the file; it
        is presumably brought into scope elsewhere (TODO confirm).
        """

        def evaluate(
            self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
        ):
            """Generate answers for the evaluation set and log the accuracy.

            Args:
                eval_dataset: Dataset to score; falls back to
                    ``self.eval_dataset`` when ``None``.
                ignore_keys: Accepted for signature compatibility; not used.
                metric_key_prefix: Prefix of the logged accuracy key.

            Returns:
                A dict with ``"<prefix>_accuracy"`` and ``"epoch"``.
            """
            # Honour an explicitly passed dataset instead of always scoring
            # the one stored on the trainer.
            dataset = self.eval_dataset if eval_dataset is None else eval_dataset

            tokenized_samples = tokenize_validation(
                self.processing_class, dataset, self.args.max_prompt_length
            )
            eval_acc = generate_gsm8k(
                self.model,
                self.processing_class,
                tokenized_samples,
                self.args.per_device_eval_batch_size,
                self.args.max_completion_length,
            )

            output = {
                f"{metric_key_prefix}_accuracy": eval_acc,
                "epoch": self.state.epoch,
            }

            self.log(output)
            # Let registered callbacks react to the finished evaluation.
            self.control = self.callback_handler.on_evaluate(
                self.args, self.state, self.control, output
            )

            return output

    # Configuration object handed to the trainer via `args=`. `GRPOConfig`,
    # `learning_rate` and the elided options are not defined in the visible
    # part of this file.
    training_args = GRPOConfig(
    output_dir=f"checkpoints/qwen25-05b",
    bf16=True,
    # Read back in CustomTrainer.evaluate and forwarded to tokenize_validation.
    max_prompt_length=356,
    # Read back in CustomTrainer.evaluate as the cap on generated tokens.
    max_completion_length=512,
    learning_rate=learning_rate,
    # Placeholder from the original snippet, not valid Python: the remaining
    # options have to be filled in here.
    ... rest of config,
    # NOTE(review): with eval_strategy="steps" this presumably triggers an
    # evaluation every 20 steps; confirm against the TrainingArguments docs.
    eval_steps=20,
    # Read back in CustomTrainer.evaluate as the generation batch size.
    per_device_eval_batch_size=256, # adjust based on your GPU! may cause oom error
    do_eval=True,
    eval_strategy="steps"
    )

    # Build the trainer. `model`, `tokenizer`, the three reward functions,
    # `dataset` and `test_dataset` are not defined in the visible part of
    # this file and must exist before this statement runs.
    trainer = CustomTrainer(
    model=model,
    # CustomTrainer.evaluate reads this back as self.processing_class and uses
    # it both to tokenize prompts and to decode completions.
    processing_class=tokenizer,
    reward_funcs=[
    strict_format_reward_func,
    int_reward_func,
    correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    # NOTE(review): presumably stored by the base trainer as self.eval_dataset,
    # which CustomTrainer.evaluate scores; confirm against the base class.
    eval_dataset=test_dataset,
    )