"""
gresnet.py: (Good/Great/Godlike/Gangster ResNet) Implementation adapted from torchvision ResNet50 v1.4.
"""

import math
from typing import Any, Callable, Optional, Sequence, Type, Tuple, Union

from torch import Tensor
import torch
import torch.nn as nn

from ..utils import _log_api_usage_once
from ._api import register_model, WeightsEnum
from ._utils import _ovewrite_named_param


__all__ = [
    "GResNet",
    #"resnetd50",
    "resnete50",
    "gresnet50",
    "gresnet101",
    "gresnet152",
]


class _Affine(nn.Module):
    """Learnable per-channel scale (``gamma``) and shift (``beta``): ``x * gamma + beta``.

    Both parameters have shape ``(num_features, 1, ..., 1)`` with ``spatial_dims``
    trailing singleton axes, so they broadcast against an input whose channel axis
    is the one just before its last ``spatial_dims`` axes (e.g. ``(N, C, H, W)``
    when ``spatial_dims == 2``).

    Args:
        num_features: number of channels.
        spatial_dims: number of trailing spatial axes in the input.
    """

    def __init__(self, num_features: int, spatial_dims: int) -> None:
        super().__init__()
        self.num_features = num_features
        self.spatial_dims = spatial_dims
        dims = (num_features,) + (1,) * spatial_dims
        self.gamma = nn.Parameter(torch.empty(dims))
        self.beta = nn.Parameter(torch.empty(dims))
        # torch.empty leaves the storage uninitialized; start as the identity
        # transform so the layer is usable even outside GResNet's init loop.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset to the identity transform (``gamma = 1``, ``beta = 0``)."""
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x: Tensor) -> Tensor:
        return x * self.gamma + self.beta

    def extra_repr(self) -> str:
        return f"{self.num_features}"


class Affine1d(_Affine):
    """Per-channel affine for inputs with one trailing spatial axis, e.g. ``(N, C, L)``."""

    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=1)


class Affine2d(_Affine):
    """Per-channel affine for inputs with two trailing spatial axes, e.g. ``(N, C, H, W)``."""

    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=2)


# Values accepted by ``conv_blk(norm_type=...)``; ``None`` means no normalization.
_NORM_TYPES = (None, "batch", "affine", "split")


def conv_blk(
    in_planes: int,
    out_planes: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
) -> nn.Sequential:
    """Build ``[pre-norm] -> Conv2d(bias=False) -> [post-norm] -> [activation]``.

    ``norm_type`` selects the normalization layout:

    * ``"batch"``:  Conv2d -> BatchNorm2d(affine=True)
    * ``"affine"``: Conv2d -> Affine2d
    * ``"split"``:  BatchNorm2d(affine=False) on the *input* -> Conv2d -> Affine2d
    * ``None``:     Conv2d only

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size, stride, padding, dilation, groups: forwarded to ``nn.Conv2d``.
        norm_type: one of the values above.
        act_layer: activation class, instantiated with ``inplace=True``;
            ``None`` omits the activation.

    Returns:
        An ``nn.Sequential`` of the layers above, in that order.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported values.
    """
    if norm_type not in _NORM_TYPES:
        raise ValueError(
            f"unknown norm_type {norm_type!r}; expected one of {_NORM_TYPES}"
        )

    layers = []
    # pre-conv norm: "split" normalizes the input without learnable params ...
    if norm_type == "split":
        layers.append(nn.BatchNorm2d(in_planes, affine=False))
    # convolution (bias is redundant with the following norm/affine shift)
    layers.append(nn.Conv2d(
        in_planes, out_planes, kernel_size,
        stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False,
    ))
    # ... and applies the learnable per-channel affine after the conv.
    if norm_type in ("split", "affine"):
        layers.append(Affine2d(out_planes))
    elif norm_type == "batch":
        layers.append(nn.BatchNorm2d(out_planes, affine=True))
    # activation
    if act_layer is not None:
        layers.append(act_layer(inplace=True))
    return nn.Sequential(*layers)


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, added to a shortcut.

    The last 1x1 stage has no activation; the activation (``self.relu``) is applied
    after the residual addition.

    Args:
        in_planes: input channels.
        planes: base width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        groups: groups of the 3x3 convolution.
        base_width: per-group width; inner width is ``int(planes * base_width / 64) * groups``.
        downsample: optional module applied to the shortcut. The shortcut must
            match the main path's shape for the in-place addition.
        norm_type: forwarded to :func:`conv_blk`.
        act_layer: activation class; ``None`` disables all activations in the block.
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        downsample: Optional[nn.Module] = None,
        norm_type: Optional[str] = "batch",
        act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        width = int(planes * (base_width / 64.0)) * groups

        self.conv1 = conv_blk(
            in_planes, width, 1,
            norm_type=norm_type, act_layer=act_layer,
        )
        self.conv2 = conv_blk(
            width, width, 3, stride=stride, padding=1, groups=groups,
            norm_type=norm_type, act_layer=act_layer,
        )
        # No activation here: it is applied after the residual addition.
        self.conv3 = conv_blk(
            width, planes * self.expansion, 1,
            norm_type=norm_type, act_layer=None,
        )
        self.relu = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        out = self.relu(out)

        return out


class GResNet(nn.Module):
    """ResNet-style image classifier built from :func:`conv_blk` stages.

    Layout: a stem (a three-conv "tiered" stem with stride 2 and no max-pool),
    followed by four stages of ``block`` that each use stride 2, then adaptive
    average pooling, flatten, and a linear head. The overall stride is 32.

    Args:
        block: residual block class (e.g. :class:`Bottleneck`).
        layers: number of blocks in each of the four stages (``layers[0..3]`` are read).
        num_classes: output size of the linear head.
        zero_init_residual: zero the scale of the last norm/affine layer in each
            block's residual branch, so every block starts as (act of) identity.
        groups: groups of each block's 3x3 conv (also used by the shortcut conv).
        width_per_group: ``base_width`` forwarded to each block.
        in_planes: stem output channels; the tiered stem also uses
            ``in_planes // 4`` and ``in_planes // 2``.
        norm_type: forwarded to :func:`conv_blk` (see there for values).
        act_layer: activation class used throughout.
    """

    def __init__(
        self,
        block: Type[Bottleneck],
        layers: Sequence[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        in_planes: int = 128,
        norm_type: Optional[str] = "batch",
        act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self._groups = groups
        self._in_planes = in_planes
        self._norm_type = norm_type
        self._act_layer = act_layer
        self.base_width = width_per_group

        self.stem = self._make_stem(stem_type="tiered", use_pool=False)
        # layer1 is strided too, since the stem above has no max-pool.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.classifier = self._make_classifier(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if m.affine:
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, _Affine):
                nn.init.constant_(m.gamma, 1)
                nn.init.constant_(m.beta, 0)

        # Zero-initialize the scale of the last norm/affine layer in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if not isinstance(m, Bottleneck):
                    continue
                # conv3 is built with act_layer=None, so its last module is the
                # post-conv norm (or the Conv2d itself when norm_type is None).
                last = m.conv3[-1]
                if isinstance(last, nn.BatchNorm2d) and last.weight is not None:
                    nn.init.zeros_(last.weight)
                elif isinstance(last, _Affine):
                    nn.init.zeros_(last.gamma)

    def _make_stem(
        self,
        stem_type: str = "tiered",
        use_pool: bool = False,
    ) -> nn.Sequential:
        """Build the input stem producing ``self._in_planes`` channels at stride 2 (4 with pool).

        ``"tiered"`` uses three 3x3 convs (3 -> C/4 -> C/2 -> C, only the first strided);
        any other value uses a single 7x7 stride-2 conv. The first conv always uses
        ``norm_type="affine"`` regardless of ``self._norm_type``.
        """
        layers = []
        if stem_type == "tiered":
            layers.append(conv_blk(
                3, self._in_planes // 4, 3, stride=2, padding=1,
                norm_type="affine", act_layer=self._act_layer,
            ))
            layers.append(conv_blk(
                self._in_planes // 4, self._in_planes // 2, 3, padding=1,
                norm_type=self._norm_type, act_layer=self._act_layer,
            ))
            layers.append(conv_blk(
                self._in_planes // 2, self._in_planes, 3, padding=1,
                norm_type=self._norm_type, act_layer=self._act_layer,
            ))
        else:
            layers.append(conv_blk(
                3, self._in_planes, 7, stride=2, padding=3,
                norm_type="affine", act_layer=self._act_layer,
            ))

        if use_pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        return nn.Sequential(*layers)

    def _make_layer(
        self,
        block: Type[Bottleneck],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks; only the first one is strided.

        Side effect: sets ``self._in_planes`` to the stage's output channel count
        (``planes * block.expansion``) for the next stage.
        """
        downsample = None
        if stride != 1 or self._in_planes != planes * block.expansion:
            # Shortcut: average-pool for the spatial reduction, then a stride-1
            # 1x1 conv for the channel change (ResNet-D style downsampling).
            # NOTE(review): this 1x1 conv is grouped (groups=self._groups), unlike
            # torchvision's shortcut; left as-is to keep parameter shapes stable.
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride=stride, ceil_mode=True),
                conv_blk(
                    self._in_planes, planes * block.expansion, 1, stride=1,
                    groups=self._groups,
                    norm_type=self._norm_type, act_layer=None,
                ),
            )

        layers = [
            block(
                self._in_planes, planes, stride=stride, downsample=downsample,
                groups=self._groups, base_width=self.base_width,
                norm_type=self._norm_type, act_layer=self._act_layer,
            )
        ]
        self._in_planes = planes * block.expansion
        layers.extend(
            block(
                self._in_planes, planes,
                groups=self._groups, base_width=self.base_width,
                norm_type=self._norm_type, act_layer=self._act_layer,
            )
            for _ in range(1, blocks)
        )

        return nn.Sequential(*layers)

    def _make_classifier(
        self,
        num_features: int,
        num_classes: int = 1000,
    ) -> nn.Sequential:
        """Global average pool -> flatten -> linear head."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(num_features, num_classes),
        )

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _gresnet(
    block: Type[Bottleneck],
    layers: Sequence[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> GResNet:
    """Construct a :class:`GResNet` and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` is set from the weights' category
    list (via ``_ovewrite_named_param``) before construction.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = GResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


# this is not quite ResNet-D (the stem is different)
#@register_model()
#def resnetd50(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
#    return _gresnet(
#        Bottleneck,
#        [3, 4, 6, 3],
#        weights,
#        progress,
#        **kwargs
#    )


@register_model()
def resnete50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """50-layer GResNet: stages (3, 4, 6, 3), SiLU activations, default ``norm_type`` ("batch")."""
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        **kwargs
    )


@register_model()
def gresnet50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """50-layer GResNet: stages (3, 4, 6, 3), SiLU activations, ``norm_type="split"``."""
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )


@register_model()
def gresnet101(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """101-layer GResNet: stages (3, 4, 23, 3), SiLU activations, ``norm_type="split"``."""
    return _gresnet(
        Bottleneck,
        (3, 4, 23, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )


@register_model()
def gresnet152(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """152-layer GResNet: stages (3, 8, 36, 3), SiLU activations, ``norm_type="split"``."""
    return _gresnet(
        Bottleneck,
        (3, 8, 36, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )