import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.autograd import Function
from scipy.spatial.distance import pdist
from core.config import cfg
import nn as mynn
import utils.net as net_utils
import math
import gc
from joblib import Parallel, delayed
from modeling.sparse_activations import Sparsemax
from poincare_embeddings.hype import poincare
from hyperbolic_cones import my_poincare_model as mpm

DEBUG = False


class GradScaler(Function):
    """
    Gradient scaler layer
    Based off:
    https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
    """
    def __init__(self, scaler=0.0):
        self.scaler = scaler

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return (grad_output * self.scaler)


def grad_scale(x):
    return GradScaler()(x)


class Poincare(nn.Module):
    def __init__(self):
        super(Poincare, self).__init__()
        self.eps = 1e-5

    def forward(self, u, v):
        eps = self.eps
        squnorm = torch.clamp(torch.sum(u * u, dim=-1), 0, 1 - eps)
        sqvnorm = torch.clamp(torch.sum(v * v, dim=-1), 0, 1 - eps)
        sqdist = torch.sum(torch.pow(u - v, 2), dim=-1)
        #ctx.eps = eps
        #ctx.save_for_backward(u, v, squnorm, sqvnorm, sqdist)
        x = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh
        z = torch.sqrt(torch.pow(x, 2) - 1)
        return torch.log(x + z)


##### Self-Attention Relation Networks Recreation ######

class SelfAttnMat(nn.Module):
    """ Visual appearance features to compute Self-attention Matrix """
    def __init__(self, feat_dim=2048, proj_dim=256, T=1.0, use_poincare=False):
        super(SelfAttnMat, self).__init__()
        self.proj_dim = proj_dim
        self.T = T
        self.sparsemax = Sparsemax(dim=2)
        self.d_k_sqrt = math.sqrt(self.proj_dim)
        self.proj_w1 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)  # [feat_dim x proj_dim]
        self.proj_w2 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)
        self._init_weights()
        self.use_poincare = use_poincare
        self.pm = poincare.PoincareManifold()
        self.my_pc = Poincare()
        # EDIT: scale and shift the distance on poincare disc
        self.poinc_scale = nn.Parameter(torch.tensor([1.0]))
        self.poinc_shift = nn.Parameter(torch.tensor([0.0]))

    def _init_weights(self):
        mynn.init.XavierFill(self.proj_w1.weight)
        init.constant_(self.proj_w1.bias, 0)
        mynn.init.XavierFill(self.proj_w2.weight)
        init.constant_(self.proj_w2.bias, 0)

    def forward(self, region_feature, num_imgs, iou_mat=[]):
        """ Return adjacency matrix as scaled dot-product self-attention """
        # Send scaled (or zero) gradients to rest of net
        region_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(region_feature)

        # Project down the region features [n_img*n_region x n_dim x 1 x 1]
        feat_key = self.proj_w1(region_feature)
        feat_query = self.proj_w2(region_feature)

        # Reshape from (n_img*n_region, n_dim, 1, 1) to (n_img, n_region, n_dim, 1, 1)
        sz = feat_key.shape
        feat_key = feat_key.view(num_imgs, int(sz[0] / num_imgs), sz[1], sz[2], sz[3])
        feat_query = feat_query.view(num_imgs, int(sz[0] / num_imgs), sz[1], sz[2], sz[3])

        use_poincare = self.use_poincare
        if use_poincare:
            import time
            start = time.time()
            device_id = feat_key.get_device()
            n_img = feat_key.shape[0]
            n_region = feat_key.shape[1]
            R_new = []
            for im in range(n_img):
                u = feat_key[im].squeeze(-1).squeeze(-1)    # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                v = feat_query[im].squeeze(-1).squeeze(-1)  # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                # normalize to unit ball
                u = F.normalize(u, p=1, dim=1)
                v = F.normalize(v, p=1, dim=1)
                # per-image pairwise distance matrix, allocated on the same
                # device/dtype as the features
                A = u.new_zeros((n_region, n_region))
                # slow version: iterate through regions
                for i in range(n_region):
                    A[i, :] = self.pm.distance(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v)
                # slow version sped up with multiprocessing -- pytorch DataLoader threads cry
                #pool = Parallel(n_jobs=2)(
                #    delayed(self.pm.distance)(u[i,:].unsqueeze(0).expand(n_region,u.shape[1]),v) for i in range(n_region)
                #    )
                #A = [(i,self.pm.distance(u[i,:].unsqueeze(0).expand(n_region,u.shape[1]),v)) for i in range(n_region)]
                #import pdb; pdb.set_trace();
                # broadcast version -- poincare grad throws an error
                #A = self.pm.distance(u.unsqueeze(1),v.unsqueeze(1).transpose(0,1))
                # broadcast with my Poincare (above): directly uses torch autograd, not the poincare grad
                #A = self.my_pc(u.unsqueeze(1),v.unsqueeze(1).transpose(0,1))
                R_new.append(A.unsqueeze(0))
                del A
            R_new = torch.cat(R_new).cuda(device_id)  # [n_img x n_region x n_region]
            R_new = (1 - R_new)
            end = time.time()
            print('==', end - start)
        else:
            feat_key = feat_key.unsqueeze(2)         # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.unsqueeze(2)     # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.transpose(1, 2)  # [n_img x 1 x n_region x n_dim x 1 x 1]
            # broadcast: [n_img x n_region x n_region x n_dim x 1 x 1]
            R_new = feat_key * feat_query
            R_new = R_new.squeeze(-1).squeeze(-1)  # [n_img x n_region x n_region x n_dim]
            R_new = R_new.sum(3) / self.d_k_sqrt   # self-attention/relation network has sqrt(d_k)

        if cfg.TRAIN.ATTN_IOU_THRESH:
            # mask out region pairs with IoU > TRAIN.IOU_THRESH
            assert num_imgs == 1  # TODO - extend to multiple images
            assert len(iou_mat) > 0
            device_id = R_new.get_device()
            mask = (iou_mat < cfg.TRAIN.IOU_THRESH).astype('float32')
            np.fill_diagonal(mask, 1.0)
            mask = Variable(torch.from_numpy(mask), requires_grad=False).cuda(device_id)
            mask = mask.view(R_new.shape)
            R_new = R_new * mask
            R_new = R_new.contiguous()

        R_new = R_new / (self.T)  # softmax temperature

        if cfg.TRAIN.SPARSEMAX:
            out = self.sparsemax(R_new)
        else:
            if cfg.TRAIN.DROPOUT > 0:
                R_new = F.dropout(R_new, p=cfg.TRAIN.DROPOUT, inplace=True)
            if use_poincare:
                #out = (R_new / R_new.sum(dim=2))
                out = (R_new / R_new.sum(dim=2).unsqueeze(2))
                # EDIT: If you are using "softmax" poincare
                # R_new = (-self.poinc_scale * R_new) + self.poinc_shift
                # Then do softmax instead of row-sum 1
            else:
                out = F.softmax(R_new, 2)
        return out


class SelfAttn_basic(nn.Module):
    """ Self Attention Network for visual context """
    def __init__(self, num_A=1, feat_dim=2048, input_feat=2048, output_feat=2048,
                 visual_proj_dim=256, combine='add'):
        """
        num_A       - Number of adjacency matrices (multi attention heads)
        feat_dim    - Size of "appearance" features for each region (roi)
        input_feat  - Size of ROI-pooled features (can be different from feat_dim)
        output_feat - Size of the output from each attention head
        """
        super(SelfAttn_basic, self).__init__()
        self.num_A = num_A
        self.proj_dim = visual_proj_dim
        self.output_feat = int(output_feat / self.num_A)
        self.combine = combine

        if cfg.TRAIN.CONTEXT_BBOX:
            feat_dim += 4  # bbox coords are appended to visual feat

        if self.num_A >= 1:
            assert cfg.TRAIN.ATTN_W  # multi-heads need down-projection with W
            for i in range(self.num_A):
                module_AdjMat = SelfAttnMat(feat_dim=feat_dim, proj_dim=self.proj_dim,
                                            T=cfg.TRAIN.SOFTMAX_T[i],
                                            use_poincare=(i in cfg.TRAIN.POINCARE))
                self.add_module('compute_AdjMat{}'.format(i), module_AdjMat)
                linear_out = nn.Conv2d(input_feat, self.output_feat, 1, stride=1)
                self._init_weights_multi(linear_out)
                self.add_module('linear_out{}'.format(i), linear_out)
        else:
            raise ValueError('num_A must be >= 1')

    def _init_weights_multi(self, linear_out):
        mynn.init.XavierFill(linear_out.weight)
        init.constant_(linear_out.bias, 0)

    def forward(self, visual_feature, x, num_imgs, iou_mat=[], bboxes=[]):
        """
        Returns features incorporating visual context from all other rois

        visual_feature - appearance feature tensor [num_rois, feat_dim, 1, 1]
        x              - region (box) feature [num_rois, feat_dim, 1, 1]
        num_imgs       - number of images per batch
        iou_mat        - (optional) IoU between regions [num_rois, num_rois]

        Visual feature and "x" can be from same or different CNN layers.
        Image IDs are typically obtained from rpn_ret['rois'] in model_builder.py
        """
        # Scale-down or zero-out gradients to rest of the (pre-trained) network
        x = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(x)

        num_rois = int(visual_feature.shape[0] / num_imgs)

        if cfg.TRAIN.ATTN_IOU_THRESH:
            assert len(iou_mat) > 0

        if cfg.TRAIN.CONTEXT_BBOX:
            bboxes = bboxes.unsqueeze(-1).unsqueeze(-1).cuda(visual_feature.get_device())
            visual_feature = torch.cat((visual_feature, bboxes), 1)
            visual_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(visual_feature)

        z = []
        for i in range(self.num_A):
            # z = A.X.W
            A_i = self._modules['compute_AdjMat{}'.format(i)](visual_feature, num_imgs, iou_mat=iou_mat)
            z_i = torch.bmm(A_i, x.view(num_imgs, num_rois, -1, 1, 1).squeeze(-1).squeeze(-1))
            z_i = z_i.view(num_imgs * num_rois, -1)
            z_i = z_i.unsqueeze(-1).unsqueeze(-1)
            z_i = self._modules['linear_out{}'.format(i)](z_i)  # [n_img*n_region, output_feat, 1, 1]
            z.append(z_i)
        z = torch.cat(z, 1)  # [n_img*n_region, num_A*output_feat, 1, 1]

        if self.combine == 'add':
            y = x + z
        elif self.combine == 'concat':
            y = torch.cat([x, z], 1)
        else:
            raise NotImplementedError
        y = F.relu(y, inplace=True)
        return y


##### END: Self-Attention Relation Networks Recreation ######


def _gen_timing_signal(length, channels=64, min_timescale=1.0, max_timescale=1.0e3):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)
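

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes the imports above resolve (i.e. the file is run from inside the
# repo) and uses small random tensors in place of real ROI features; the
# shapes are assumptions chosen to match the docstrings above. Only the
# dependency-light pieces are exercised here: the Poincare distance module
# and _gen_timing_signal. SelfAttn_basic additionally needs cfg.TRAIN.* to be
# configured, so it is not instantiated.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Sinusoidal timing signal: [1, length, channels]
    signal = _gen_timing_signal(length=5, channels=64)
    print('timing signal shape:', signal.shape)  # torch.Size([1, 5, 64])

    # Pairwise Poincare distances between two small sets of points.
    # Points are scaled by 0.1 so they lie well inside the unit ball.
    torch.manual_seed(0)
    u = 0.1 * torch.rand(3, 8)
    v = 0.1 * torch.rand(4, 8)
    pc = Poincare()
    # Broadcast to a [3 x 4] distance matrix, mirroring the commented-out
    # "broadcast with my Poincare" path in SelfAttnMat.forward above.
    dist = pc(u.unsqueeze(1), v.unsqueeze(0))
    print('pairwise distance shape:', dist.shape)  # torch.Size([3, 4])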