Skip to content

Instantly share code, notes, and snippets.

@fzliu
Created March 25, 2024 12:42
Show Gist options
  • Save fzliu/006d2043dc1e90d68ae562c5bde8066c to your computer and use it in GitHub Desktop.
Save fzliu/006d2043dc1e90d68ae562c5bde8066c to your computer and use it in GitHub Desktop.

Revisions

  1. fzliu created this gist Mar 25, 2024.
    417 changes: 417 additions & 0 deletions gresnet.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,417 @@
    """
    gresnet.py: (Good/Great/Godlike/Gangster ResNet)
    Implementation adapted from torchvision ResNet50 v1.4.
    """

    import math
    from typing import Any, Callable, Optional, Type, Tuple, Union

    from torch import Tensor
    import torch
    import torch.nn as nn

    from ..utils import _log_api_usage_once
    from ._api import register_model, WeightsEnum
    from ._utils import _ovewrite_named_param


    __all__ = [
    "GResNet",
    #"resnetd50",
    "resnete50",
    "gresnet50",
    "gresnet101",
    "gresnet152",
    ]


    class _Affine(nn.Module):
    def __init__(self, num_features: int, spatial_dims: int) -> None:
    super().__init__()
    self.num_features = num_features
    self.spatial_dims = spatial_dims
    dims = (num_features,) + (1,) * spatial_dims
    self.gamma = nn.Parameter(torch.empty(dims))
    self.beta = nn.Parameter(torch.empty(dims))

    def forward(self, x: Tensor) -> Tensor:
    x = x * self.gamma + self.beta
    return x


    class Affine1d(_Affine):
    def __init__(self, num_features: int) -> None:
    super().__init__(num_features, spatial_dims=1)

    def extra_repr(self) -> str:
    return "{num_features}".format(**self.__dict__)


    class Affine2d(_Affine):
    def __init__(self, num_features: int) -> None:
    super().__init__(num_features, spatial_dims=2)

    def extra_repr(self) -> str:
    return "{num_features}".format(**self.__dict__)


    def conv_blk(
    in_planes: int,
    out_planes: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
    ) -> nn.Sequential:
    layers = []

    # pre-conv norm
    if norm_type == "split":
    layers.append(nn.BatchNorm2d(in_planes, affine=False))

    # convolution
    layers.append(nn.Conv2d(
    in_planes,
    out_planes,
    kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
    bias=False
    ))

    # post-conv norm
    if norm_type in ["split", "affine"]:
    layers.append(Affine2d(out_planes))
    elif norm_type == "batch":
    layers.append(nn.BatchNorm2d(out_planes, affine=True))

    # activation
    if act_layer:
    layers.append(act_layer(inplace=True))

    return nn.Sequential(*layers)


    class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
    self,
    in_planes: int,
    planes: int,
    stride: int = 1,
    groups: int = 1,
    base_width: int = 64,
    downsample: Optional[nn.Module] = None,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
    ) -> None:
    super().__init__()
    self.stride = stride
    self.downsample = downsample

    width = int(planes * (base_width / 64.0)) * groups

    self.conv1 = conv_blk(
    in_planes,
    width,
    1,
    norm_type=norm_type,
    act_layer=act_layer
    )
    self.conv2 = conv_blk(
    width,
    width,
    3,
    stride=stride,
    padding=1,
    groups=groups,
    norm_type=norm_type,
    act_layer=act_layer
    )
    self.conv3 = conv_blk(
    width,
    planes * self.expansion,
    1,
    norm_type=norm_type,
    act_layer=None
    )

    self.relu = act_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
    identity = x

    out = self.conv1(x)
    out = self.conv2(out)
    out = self.conv3(out)

    if self.downsample:
    identity = self.downsample(identity)

    out += identity
    out = self.relu(out)

    return out


    class GResNet(nn.Module):
    def __init__(
    self,
    block: Type[Union[Bottleneck]],
    layers: Tuple[int],
    num_classes: int = 1000,
    zero_init_residual: bool = False,
    groups: int = 1,
    width_per_group: int = 64,
    in_planes: int = 128,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
    ) -> None:
    super().__init__()
    _log_api_usage_once(self)
    self._groups = groups
    self._in_planes = in_planes
    self._norm_type = norm_type
    self._act_layer = act_layer

    self.base_width = width_per_group
    self.stem = self._make_stem(stem_type="tiered", use_pool=False)
    self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
    self.classifier = self._make_classifier(512 * block.expansion, num_classes)

    for m in self.modules():
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
    if m.affine:
    nn.init.constant_(m.weight, 1)
    nn.init.constant_(m.bias, 0)
    elif isinstance(m, (Affine1d, Affine2d)):
    nn.init.constant_(m.gamma, 1)
    nn.init.constant_(m.beta, 0)

    # Zero-initialize the last BN in each residual branch,
    # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
    for m in self.modules():
    if isinstance(m, Bottleneck) and m.bn3.weight is not None:
    nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]

    def _make_stem(
    self,
    stem_type: str = "tiered",
    use_pool: bool = False
    ) -> nn.Sequential:
    layers = []

    if stem_type == "tiered":
    layers.append(conv_blk(
    3,
    self._in_planes // 4,
    3,
    stride=2,
    padding=1,
    norm_type="affine",
    act_layer=self._act_layer
    ))
    layers.append(conv_blk(
    self._in_planes // 4,
    self._in_planes // 2,
    3,
    padding=1,
    norm_type=self._norm_type,
    act_layer=self._act_layer
    ))
    layers.append(conv_blk(
    self._in_planes // 2,
    self._in_planes,
    3,
    padding=1,
    norm_type=self._norm_type,
    act_layer=self._act_layer
    ))
    else:
    layers.append(conv_blk(
    3,
    self._in_planes,
    7,
    stride=2,
    padding=3,
    norm_type="affine",
    act_layer=self._act_layer
    ))

    if use_pool:
    layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    return nn.Sequential(*layers)

    def _make_layer(
    self,
    block: Type[Bottleneck],
    planes: int,
    blocks: int,
    stride: int = 1
    ) -> nn.Sequential:
    downsample = None
    if stride != 1 or self._in_planes != planes * block.expansion:
    downsample = nn.Sequential(
    nn.AvgPool2d(stride, stride=stride, ceil_mode=True),
    conv_blk(
    self._in_planes,
    planes * block.expansion,
    1,
    stride=1,
    groups=self._groups,
    norm_type=self._norm_type,
    act_layer=None
    )
    )

    layers = []
    layers.append(
    block(
    self._in_planes,
    planes,
    stride=stride,
    downsample=downsample,
    groups=self._groups,
    base_width=self.base_width,
    norm_type=self._norm_type,
    act_layer=self._act_layer
    )
    )
    self._in_planes = planes * block.expansion
    for _ in range(1, blocks):
    layers.append(
    block(
    self._in_planes,
    planes,
    groups=self._groups,
    base_width=self.base_width,
    norm_type=self._norm_type,
    act_layer=self._act_layer
    )
    )

    return nn.Sequential(*layers)

    def _make_classifier(
    self,
    num_features: int,
    num_classes: Optional[int] = 1000
    ) -> nn.Sequential:
    layers = []
    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten())
    layers.append(nn.Linear(num_features, num_classes))
    return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
    x = self.stem(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.classifier(x)
    return x

    def forward(self, x: Tensor) -> Tensor:
    return self._forward_impl(x)


    def _gresnet(
    block: Type[Bottleneck],
    layers: Tuple[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
    ) -> GResNet:
    if weights is not None:
    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = GResNet(block, layers, **kwargs)

    if weights is not None:
    model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


    # this is not quite ResNet-D (the stem is different)
    #@register_model()
    #def resnetd50(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
    # return _gresnet(
    # Bottleneck,
    # [3, 4, 6, 3],
    # weights,
    # progress,
    # **kwargs
    # )


    @register_model()
    def resnete50(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
    Bottleneck,
    (3, 4, 6, 3),
    weights,
    progress,
    act_layer=nn.SiLU,
    **kwargs
    )


    @register_model()
    def gresnet50(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
    Bottleneck,
    (3, 4, 6, 3),
    weights,
    progress,
    act_layer=nn.SiLU,
    norm_type="split",
    **kwargs
    )


    @register_model()
    def gresnet101(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
    Bottleneck,
    (3, 4, 23, 3),
    weights,
    progress,
    act_layer=nn.SiLU,
    norm_type="split",
    **kwargs
    )


    @register_model()
    def gresnet152(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
    Bottleneck,
    (3, 8, 36, 3),
    weights,
    progress,
    act_layer=nn.SiLU,
    norm_type="split",
    **kwargs
    )