import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class NonLocalBlock(nn.Module):
    """Non-local (self-attention) block built from spectrally normalised 1x1 convs.

    Queries are computed at full resolution; keys and values are 2x2
    max-pooled first, so every one of the H*W input positions attends over
    (H // 2) * (W // 2) pooled positions. The attended features are projected
    back to ``conv_dim`` channels and added to the input, scaled by a
    learnable ``gamma`` initialised to zero (a freshly built block therefore
    returns its input unchanged).

    Args:
        conv_dim: number of input (and output) channels. Must be >= 8 because
            the query/key projections use ``conv_dim // 8`` channels.

    Raises:
        ValueError: if ``conv_dim < 8``.
    """

    def __init__(self, conv_dim):
        super().__init__()
        if conv_dim < 8:
            # conv_dim // 8 would be 0: zero-channel query/key convs, which
            # spectral_norm rejects with an unhelpful reshape error.
            raise ValueError(f"conv_dim must be >= 8, got {conv_dim}")
        # Attribute names/order are kept as-is so existing checkpoints load.
        # Query projection: conv_dim -> conv_dim // 8.
        self.conv1 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 8, 1, 1, 0))
        # Key projection: conv_dim -> conv_dim // 8 (pooled in forward).
        self.conv2 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 8, 1, 1, 0))
        # Value projection: conv_dim -> conv_dim // 2 (pooled in forward).
        self.conv3 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 2, 1, 1, 0))
        # Output projection back to conv_dim channels.
        self.conv4 = spectral_norm(nn.Conv2d(conv_dim // 2, conv_dim, 1, 1, 0))
        # 2x2 pooling of keys/values shrinks the attention map by 4x.
        self.downsample = nn.MaxPool2d(2, 2)
        # Residual scale; starts at 0 so the block is initially the identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``x + gamma * attention(x)``.

        Args:
            x: tensor of shape (N, conv_dim, H, W), e.g. (N, 1024, 8, 8).

        Returns:
            Tensor with the same shape as ``x``.
        """
        N, _, H, W = x.shape
        # (N, C//8, H*W)
        query = self.conv1(x).flatten(2)
        # (N, C//8, (H//2)*(W//2)) after 2x2 max-pool
        key = self.downsample(self.conv2(x)).flatten(2)
        # (N, H*W, (H//2)*(W//2)); normalise over the pooled positions
        attn = F.softmax(torch.bmm(query.transpose(1, 2), key), dim=2)
        # (N, C//2, (H//2)*(W//2))
        value = self.downsample(self.conv3(x)).flatten(2)
        # (N, C//2, H*W) -> (N, C//2, H, W)
        out = torch.bmm(value, attn.transpose(1, 2)).reshape(N, -1, H, W)
        # (N, C, H, W)
        return x + self.gamma * self.conv4(out)