# Sanity script: for every torchvision segmentation model builder, compare the
# parameter gradients produced by functorch's `grad` transform with the ones
# produced by a regular autograd backward pass.
import torch
import torch.nn as nn
import torchvision
import torchvision.models.segmentation as models

from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad


# Collect the names in the segmentation namespace that are callable and start
# with neither an underscore nor an upper-case letter.
tested_models = []
for model_name in vars(models):
    if not model_name.startswith("_") and not model_name[0].isupper():
        if callable(vars(models)[model_name]):
            tested_models.append(model_name)


criterion = nn.CrossEntropyLoss()


def compute_grads(model, image, target):
    """Run a regular forward/backward pass of ``model`` on ``image``.

    The loss is ``criterion`` applied to the ``"out"`` entry of the model
    output and ``target``.  Nothing is returned; ``loss.backward()`` is
    called so gradients are accumulated by autograd.
    """
    # Reseed right before the forward pass (the original note: fix the seed
    # to fix dropout), mirroring what the functional path does.
    torch.manual_seed(0)
    prediction = model(image)
    criterion(prediction["out"], target).backward()


device = 'cpu'


def check_grads_model(model_name, device):
    """Assert that functorch and autograd gradients agree for one model.

    Args:
        model_name: key of a model builder in ``models.__dict__``.
        device: device on which the model and the random inputs are created.

    Raises:
        AssertionError: if the number of gradients differs from the number
            of parameters, or if any gradient pair is not close within
            ``atol=1e-5``.
    """
    batch_size = 8
    torch.manual_seed(0)
    size = (224, 224)

    builder = models.__dict__[model_name]
    model = builder(num_classes=10, pretrained=False, pretrained_backbone=False).to(device)

    # Random images plus per-pixel integer targets drawn from [0, 10).
    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size, ) + size, device=device)

    # Pick the functorch helper depending on whether the model has buffers.
    has_buffers = any(True for _ in model.buffers())
    if has_buffers:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None

    def compute_loss_ft(weights, buffers, image, target):
        # Fix seed to fix dropout
        torch.manual_seed(0)
        call_args = (weights, image) if buffers is None else (weights, buffers, image)
        prediction = func_model(*call_args)
        return criterion(prediction["out"], target)

    # Functional gradients first (`grad` differentiates w.r.t. the first
    # argument by default), then the regular backward pass on the module.
    w_grads = grad(compute_loss_ft)(weights, buffers, images, targets)
    compute_grads(model, images, targets)

    n_params = sum(1 for _ in model.parameters())
    assert len(w_grads) == n_params

    for wg, (n, p) in zip(w_grads, model.named_parameters()):
        assert p.grad.allclose(wg, atol=1e-5), f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"


# Report library versions, then check every discovered model; an assertion
# failure is printed and the loop moves on to the next model.
print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)