"""
MFU (Model FLOPs Utilization) callback for the Hugging Face Trainer, logged to Weights & Biases.
Adapted from https://www.oumi.ai/docs/en/latest/_modules/oumi/core/callbacks/hf_mfu_callback.html
"""

import time
from typing import Optional

import torch
import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl

# Theoretical peak tensor core performance in TFLOPS (BF16 / FP16).
DEVICES = {
    'NVIDIA GeForce RTX 3090': 142,
    'NVIDIA GeForce RTX 3090 Ti': 160,
    'NVIDIA H100 80GB HBM3': 990,
    'NVIDIA GH200 480GB': 990,
}


class WandbMFUCallback(TrainerCallback):
    def __init__(self, device_name=None):
        self._time_of_second_step: Optional[float] = None
        self._flops_at_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._first_step_finished = False

        self._device_name = device_name
        if self._device_name is None:
            self._device_name = torch.cuda.get_device_name()
        if self._device_name not in DEVICES:
            raise ValueError(
                f'device name {self._device_name!r} is not recognized, add its peak TFLOPS to DEVICES.'
            )

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self._step_start_time = time.time()
        # Skip the first step: it is dominated by warmup overhead and would skew MFU.
        if not self._first_step_finished:
            return
        if self._time_of_second_step is None:
            # Anchor timing and FLOPs counting at the start of the second step.
            self._time_of_second_step = self._step_start_time
            if state is not None:
                self._flops_at_second_step = state.total_flos

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True
            return
        # Accumulate time spent strictly inside training steps.
        self._time_for_train_steps += delta_time_seconds

    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if self._time_of_second_step is None:
            return
        # Wall-clock time since the second step (includes evaluation, checkpointing, logging, ...).
        delta_time_seconds_train = time.time() - self._time_of_second_step
        # Time spent only inside training steps.
        delta_time_seconds_step = self._time_for_train_steps

        if self._flops_at_second_step is not None and (
            state is not None and state.total_flos > 0.0
        ):
            flops_since_second_step_on_all_devices = (
                state.total_flos - self._flops_at_second_step
            )
            flops_step = flops_since_second_step_on_all_devices / delta_time_seconds_step
            flops_train = flops_since_second_step_on_all_devices / delta_time_seconds_train
            device_flops_per_second = DEVICES[self._device_name] * 1e12
            # MFU = achieved FLOPs per second divided by the device's theoretical peak.
            train_step_mfu = flops_step / device_flops_per_second
            train_mfu = flops_train / device_flops_per_second
            wandb.log({
                "train_step_mfu": train_step_mfu,
                "train_mfu": train_mfu,
            }, step=state.global_step)
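For context, a minimal usage sketch follows, showing how the callback plugs into a Hugging Face Trainer. The `model` and `train_dataset` objects and the TrainingArguments values are assumptions for illustration only; the callback itself just needs an active W&B run (for example via report_to='wandb' or an explicit wandb.init()). The two logged values differ only in the denominator's time base: train_step_mfu uses time spent inside training steps, train_mfu uses wall-clock time since the second step, so the gap between them reflects time lost to evaluation, logging, and other non-step work.

# Minimal usage sketch (assumptions: `model` and `train_dataset` are defined elsewhere;
# the TrainingArguments values below are illustrative placeholders).
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir='./out',
    per_device_train_batch_size=4,
    bf16=True,               # the peak TFLOPS table assumes BF16/FP16 tensor cores
    logging_steps=10,        # on_log fires at this cadence, so MFU is logged every 10 steps
    report_to='wandb',       # starts the W&B run that wandb.log() writes into
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    callbacks=[WandbMFUCallback()],  # device name auto-detected via torch.cuda.get_device_name()
)
trainer.train()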