def index(x, axis, idxs):
    """
    Select slices of ``x`` along ``axis`` in the order given by ``idxs``.

    Inputs:
    - x: torch.Tensor with x.dim() == N
    - axis: Integer with 0 <= axis < N (negative values count from the end)
    - idxs: List of integers (or a 1-D integer tensor), with
      0 <= idxs[i] < x.size(axis)

    Returns:
    y: torch.Tensor satisfying y.select(axis, i) == x.select(axis, idxs[i]).
    y has the same shape as x except that y.size(axis) == len(idxs). It has
    the same dtype as x and lives on the same device as x.
    """
    # Build the index tensor on x's device. A CPU index (as produced by
    # torch.LongTensor) cannot be used to index a tensor on another device.
    idx = torch.as_tensor(idxs, dtype=torch.long, device=x.device)
    # index_select does the view/expand/gather work in a single op.
    return x.index_select(axis, idx)