"""Benchmark torchvision transforms on PIL images versus tensors."""
from collections import namedtuple

import PIL
from PIL import Image
import torch
import torch.utils.benchmark as benchmark
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F


# Description of one benchmark case:
#   op               -- callable applied to the input as ``op(x)``
#   input_size       -- [height, width] of the randomly generated input
#   name             -- label override; None means use the op's class name
#   supported_dtypes -- dtypes to benchmark; None means every dtype tried
BTransform = namedtuple("BTransform", ["op", "input_size", "name", "supported_dtypes"])


transforms = [
    BTransform(
        op=T.Resize([256, 256], interpolation=T.InterpolationMode.BILINEAR),
        input_size=[500, 500],
        name=None,
        supported_dtypes=None,
    ),
    BTransform(
        op=T.RandomHorizontalFlip(p=1.0),
        input_size=[256, 256],
        name=None,
        supported_dtypes=None,
    ),
    BTransform(
        op=T.RandomResizedCrop(224),
        input_size=[500, 500],
        name=None,
        supported_dtypes=None,
    ),
    BTransform(
        op=T.autoaugment.RandAugment(),
        input_size=[224, 224],
        name=None,
        supported_dtypes=[torch.uint8, ],
    ),
    # ImageNet train preset:
    BTransform(
        op=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(p=0.5),
            # PIL inputs are converted to tensors here; tensor inputs pass through.
            lambda x: F.pil_to_tensor(x) if isinstance(x, PIL.Image.Image) else x,
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]),
        name="ImageNet train",
        input_size=[500, 500],
        supported_dtypes=[torch.uint8, ],
    ),
]


def _time_transform(transform, x, label, sub_label, description, min_run_time):
    """Time ``transform(x)`` and return the resulting measurement."""
    timer = benchmark.Timer(
        stmt="t(x)",
        globals={"x": x, "t": transform},
        num_threads=torch.get_num_threads(),
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)


def run_bench(t, min_run_time=2):
    """Benchmark one transform on PIL and tensor inputs.

    Args:
        t: A ``BTransform`` describing the op, its input size, an optional
            display name and the dtypes it supports.
        min_run_time: Value forwarded to ``Timer.blocked_autorange`` for
            every measurement.

    Returns:
        A list of measurements: for each benchmarked dtype, the PIL result
        followed by the tensor result.
    """
    transform = t.op
    label = transform.__class__.__name__ if t.name is None else t.name

    results = []
    for dtype in (torch.uint8, torch.float32):
        if t.supported_dtypes is not None and dtype not in t.supported_dtypes:
            continue

        # float32 inputs are benchmarked as single-channel "F" images,
        # uint8 inputs as three-channel "RGB" images.
        is_float = dtype == torch.float32
        channels, mode = (1, "F") if is_float else (3, "RGB")

        tensor = torch.randint(0, 256, size=[channels, *t.input_size], dtype=dtype)
        # Channels-last numpy copy of the same data, used to build the PIL image.
        data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
        if is_float:
            # Drop the trailing channel axis: the "F" image is built from a 2-D array.
            data = data[..., 0]
        pil_img = Image.fromarray(data, mode=mode)

        sub_label = f"{dtype} / {mode}"
        # Order matters for the returned list: PIL first, then tensor.
        results.append(
            _time_transform(transform, pil_img, label, sub_label, "Transform on PIL", min_run_time)
        )
        results.append(
            _time_transform(transform, tensor, label, sub_label, "Transform on Tensor", min_run_time)
        )

    return results


def main():
    """Run every benchmark in ``transforms`` and print the comparison table."""
    all_results = []
    for t in transforms:
        all_results += run_bench(t)

    compare = benchmark.Compare(all_results)
    compare.print()


if __name__ == "__main__":
    print(f"Torch config: {torch.__config__.show()}")
    print(f"Num threads: {torch.get_num_threads()}")
    print(f"Torch version: {torch.__version__}")
    print(f"Torchvision version: {torchvision.__version__}")
    print(f"PIL version: {PIL.__version__}")

    main()