# PyTorch code for implementing the mixture of softmaxes layer from
# "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model"
# https://arxiv.org/abs/1711.03953
import torch
import torch.nn.functional as F

context = self.fc(out)  # one projection: K * nClasses class logits, then K prior logits

# Non-log version: mix K softmaxes over the classes with softmax-normalized priors
priors = F.softmax(context[:, -self.n_components:], dim=1)
mixtures = torch.stack([priors[:, i].unsqueeze(1)
                        * F.softmax(context[:, i * self.nClasses : (i + 1) * self.nClasses], dim=1)
                        for i in range(self.n_components)], 1)
out = torch.log(mixtures.sum(1))

# Log version (equivalent, computed in log space)
# log_priors = F.log_softmax(context[:, -self.n_components:], dim=1).unsqueeze(2)
# log_mixtures = torch.stack([F.log_softmax(context[:, i * self.nClasses : (i + 1) * self.nClasses], dim=1)
#                             for i in range(self.n_components)], 1)
# out = torch.log(torch.exp(log_priors + log_mixtures).sum(1))
# or, numerically stabler: out = torch.logsumexp(log_priors + log_mixtures, dim=1)
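
# For context, a minimal self-contained sketch of a module the snippet above
# could sit in. The class name MoSLayer, the constructor signature, and the
# usage values below are illustrative assumptions, not part of the original;
# only the forward pass mirrors the (log-version) code above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoSLayer(nn.Module):
    """Mixture of softmaxes: K softmaxes over nClasses outputs, mixed by
    K prior weights, all produced by a single linear projection."""

    def __init__(self, hidden_size, n_classes, n_components):
        super().__init__()
        self.nClasses = n_classes
        self.n_components = n_components
        # Emits K * nClasses class logits followed by K mixture-prior logits.
        self.fc = nn.Linear(hidden_size, n_components * n_classes + n_components)

    def forward(self, out):
        context = self.fc(out)
        log_priors = F.log_softmax(context[:, -self.n_components:], dim=1).unsqueeze(2)
        log_mixtures = torch.stack(
            [F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
             for i in range(self.n_components)], 1)
        # logsumexp over the component axis yields log p(class | input).
        return torch.logsumexp(log_priors + log_mixtures, dim=1)

# Hypothetical usage: batch of 4 hidden states, 10 classes, 3 components.
layer = MoSLayer(hidden_size=32, n_classes=10, n_components=3)
log_probs = layer(torch.randn(4, 32))  # shape (4, 10)
print(log_probs.exp().sum(1))          # each row sums to ~1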