""" https://www.oumi.ai/docs/en/latest/_modules/oumi/core/callbacks/hf_mfu_callback.html """ import torch import wandb from transformers import TrainerCallback, TrainerState, TrainerControl # Theoretical Peak Tensor Core Performance (BF16 / FP16) DEVICES = { 'NVIDIA GeForce RTX 3090': 142, 'NVIDIA GeForce RTX 3090 Ti': 160, 'NVIDIA H100 80GB HBM3': 990, 'NVIDIA GH200 480GB': 990, } class WandbMFUCallback(TrainerCallback): def __init__(self, device_name = None): self._time_of_second_step: Optional[float] = None self._flops_at_second_step: Optional[float] = None self._time_for_train_steps = 0.0 self._first_step_finished = False self._device_name = device_name if self._device_name is None: self._device_name = torch.cuda.get_device_name() if self._device_name not in DEVICES: raise Exception('device name is not recognized.') def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs): self._step_start_time = time.time() if not self._first_step_finished: return if self._time_of_second_step is None: self._time_of_second_step = self._step_start_time if state is not None: self._flops_at_second_step = state.total_flos def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs): delta_time_seconds = time.time() - self._step_start_time if not self._first_step_finished: self._first_step_finished = True return self._time_for_train_steps += delta_time_seconds def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs): if self._time_of_second_step is None: return delta_time_seconds_train = time.time() - self._time_of_second_step delta_time_seconds_step = self._time_for_train_steps if self._flops_at_second_step is not None and ( state is not None and state.total_flos > 0.0 ): flops_since_second_step_on_all_devices = ( state.total_flos - self._flops_at_second_step ) flops_step = flops_since_second_step_on_all_devices / delta_time_seconds_step flops_train = flops_since_second_step_on_all_devices / delta_time_seconds_train device_flops_per_second = DEVICES[self._device_name] * 1e12 train_step_mfu = flops_step / device_flops_per_second train_mfu = flops_train / device_flops_per_second wandb.log({ "train_step_mfu": train_step_mfu, "train_mfu": train_mfu, }, step=state.global_step)