hf_mfu.py · gist by @huseinzol05 · created July 23, 2025
    """
    https://www.oumi.ai/docs/en/latest/_modules/oumi/core/callbacks/hf_mfu_callback.html
    """

import time
from typing import Optional

import torch
import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl

# Theoretical peak Tensor Core performance (dense BF16/FP16), in TFLOP/s.
DEVICES = {
    'NVIDIA GeForce RTX 3090': 142,
    'NVIDIA GeForce RTX 3090 Ti': 160,
    'NVIDIA H100 80GB HBM3': 990,
    'NVIDIA GH200 480GB': 990,
}
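# MFU = achieved FLOP/s divided by the theoretical peak above. As a worked
# example (illustrative numbers, not a measurement): a run sustaining
# 4.0e14 FLOP/s on one H100 gives 4.0e14 / (990 * 1e12) ~= 0.40, i.e. 40% MFU.
# Another GPU can be supported by adding its dense BF16/FP16 peak here, e.g.:
#
#     DEVICES['NVIDIA A100-SXM4-80GB'] = 312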

class WandbMFUCallback(TrainerCallback):
    def __init__(self, device_name=None):
        # Measurement is anchored at the *second* step so that first-step
        # warmup (compilation, cache setup) does not skew the MFU numbers.
        self._time_of_second_step: Optional[float] = None
        self._flops_at_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._first_step_finished = False
        self._device_name = device_name
        if self._device_name is None:
            self._device_name = torch.cuda.get_device_name()
        if self._device_name not in DEVICES:
            raise ValueError(
                f'device name {self._device_name!r} is not recognized; '
                'add its peak TFLOP/s to DEVICES.'
            )
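    # The device name can also be pinned explicitly when auto-detection is not
    # wanted; the string must match a key in DEVICES, e.g.:
    #
    #     callback = WandbMFUCallback(device_name='NVIDIA H100 80GB HBM3')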

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self._step_start_time = time.time()
        if not self._first_step_finished:
            return

        # Anchor the wall clock and the cumulative FLOP counter once, at the
        # start of the second step.
        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time
            if state is not None:
                self._flops_at_second_step = state.total_flos

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True  # discard the first (warmup) step
            return
        self._time_for_train_steps += delta_time_seconds
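    # Two denominators feed the MFU numbers below: `_time_for_train_steps`
    # sums only the step bodies (on_step_begin -> on_step_end), while the wall
    # clock since the second step also covers evaluation, checkpointing, and
    # dataloader stalls, so in general train_step_mfu >= train_mfu.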

    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if self._time_of_second_step is None:
            return

        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps

        if self._flops_at_second_step is not None and (
            state is not None and state.total_flos > 0.0
        ):
            # `state.total_flos` is aggregated across all devices by the
            # Trainer, while DEVICES holds a single device's peak; for N-GPU
            # runs the peak should be multiplied by N.
            flops_since_second_step_on_all_devices = (
                state.total_flos - self._flops_at_second_step
            )
            flops_step = flops_since_second_step_on_all_devices / delta_time_seconds_step
            flops_train = flops_since_second_step_on_all_devices / delta_time_seconds_train
            device_flops_per_second = DEVICES[self._device_name] * 1e12  # TFLOP/s -> FLOP/s
            train_step_mfu = flops_step / device_flops_per_second
            train_mfu = flops_train / device_flops_per_second

            wandb.log({
                "train_step_mfu": train_step_mfu,
                "train_mfu": train_mfu,
            }, step=state.global_step)
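
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original gist: 'gpt2' and the toy
    # dataset are placeholders, and a CUDA GPU listed in DEVICES is assumed.
    # The Trainer accumulates `state.total_flos` across steps and, with wandb
    # enabled, initializes the wandb run that the callback logs into.
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    enc = tokenizer(['hello world'] * 64, return_tensors='pt')
    # Trainer accepts any sized, indexable dataset of feature dicts.
    train_dataset = [{'input_ids': ids, 'labels': ids} for ids in enc['input_ids']]

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir='out', logging_steps=5, report_to=['wandb']),
        train_dataset=train_dataset,
        callbacks=[WandbMFUCallback()],  # device name auto-detected via torch.cuda
    )
    trainer.train()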