from typing import Callable

import torch
from torch.utils.flop_counter import FlopCounterMode
from torchvision.models import resnet18
from triton.testing import do_bench


def get_flops_achieved(f: Callable[[], object]) -> float:
    """Print and return the throughput achieved by ``f`` in TFLOP/s.

    ``f`` is first executed once under ``FlopCounterMode`` to obtain the
    number of FLOPs a single call performs, and is then timed with Triton's
    ``do_bench``. The two figures are combined into FLOPs per second and
    scaled by 1e12.

    Args:
        f: Zero-argument callable containing the work to measure. It is
            invoked once for FLOP counting and then repeatedly by
            ``do_bench``, so it must be safe to call many times.

    Returns:
        The achieved throughput in TFLOP/s (the same number that is printed).
    """
    # Count FLOPs for exactly one invocation; display=False suppresses the
    # per-module table that FlopCounterMode would otherwise print on exit.
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    # The result of do_bench is treated as milliseconds per call, hence the
    # 1e3 factor to convert it to calls per second.
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter

    tflops_per_second = iters_per_second * total_flops / 1e12
    print(f"{tflops_per_second} TF/s")
    return tflops_per_second


# Benchmark a full training-style step (forward + backward) of ResNet-18 in
# half precision on the GPU, using a batch of 128 images of size 3x224x224.
model = resnet18().cuda().half()
inp = torch.randn(128, 3, 224, 224, device='cuda', dtype=torch.half)

# NOTE: .grad is never cleared between calls, so gradients accumulate across
# the timed iterations; their values are not used, only the elapsed time.
get_flops_achieved(lambda: model(inp).sum().backward())