import torch
import torch.nn.functional as F


def forward(self, x):
    # Backbone features: (bs, ch, h, w) -> (bs, N, ch), with N = h*w locations.
    x = self.features(x)
    bs, ch, h, w = x.shape
    x = x.view(bs, ch, -1).transpose(2, 1)
    # x.register_hook(self.save_grad('x'))

    # N x N Gram matrix over the N local features. The kernel must be this
    # Gram matrix even for first-order pooling: an elementwise square (x * x)
    # would have shape (bs, N, ch) and break K.bmm(alpha) below.
    K = x.bmm(x.transpose(2, 1))

    # Per-location kernel mass C_i = sum_j K_ij, raised to gamma; locations
    # with near-zero mass are zeroed out. Note that Ci is not consumed by the
    # update below, which solves the purely democratic target
    # alpha_i * (K alpha)_i = 1; the gamma-democratic target
    # alpha_i * (K alpha)_i = C_i**gamma would multiply the numerator by
    # Ci ** self.sinkhorn_t.
    Ci = torch.sum(K, 2, keepdim=True)
    mask = (Ci < 1e-10).detach()
    Ci = torch.pow(Ci, self.gamma)
    Ci[mask] = 0
    Ci = Ci.detach()

    # Damped Sinkhorn iterations for the location weights alpha: (bs, N, 1).
    alpha = torch.ones(bs, h * w, 1, device=x.device)
    for _ in range(10):
        alpha = torch.pow(alpha + 1e-10, 1 - self.sinkhorn_t) / \
            (torch.pow(K.bmm(alpha) + 1e-10, self.sinkhorn_t) + 1e-10)

    # Second-order variant (outer-product pooling of the reweighted features):
    # x = x * torch.pow(alpha, 0.5)
    # x = x.transpose(1, 2).bmm(x).view(bs, -1)

    # First-order variant: alpha-weighted sum over the N locations -> (bs, ch).
    x = (x * alpha).sum(1)

    # Square-root and L2 normalization of the descriptor (assumes non-negative
    # features, e.g. post-ReLU), then the classifier head.
    x = torch.sqrt(x + 1e-8)
    x = F.normalize(x)
    x = self.fc(x)
    return x
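
# --- Context sketch (not from the original) ---
# A minimal, assumed wrapper showing how the forward() above could be wired
# up end to end. The class name, the VGG-16 backbone, the classifier width,
# and the gamma / sinkhorn_t defaults are illustrative assumptions, not the
# original author's setup.

import torch.nn as nn
import torchvision


class DemocraticPooling(nn.Module):
    def __init__(self, num_classes, gamma=0.5, sinkhorn_t=0.5):
        super().__init__()
        # Assumed backbone: VGG-16 conv features, (bs, 512, h, w) output.
        self.features = torchvision.models.vgg16(weights=None).features
        self.gamma = gamma            # exponent on the kernel row sums C_i
        self.sinkhorn_t = sinkhorn_t  # damping of the Sinkhorn update
        # First-order descriptor is (bs, 512) after the alpha-weighted sum.
        self.fc = nn.Linear(512, num_classes)

    forward = forward  # the method defined above


# Usage sketch: a 224x224 input gives N = 7*7 = 49 locations per image.
model = DemocraticPooling(num_classes=200)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 200])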