@vyraun
Forked from ajbrock/Mixture_of_softmaxes.py
Created June 6, 2020 04:50
Mixture_of_softmaxes.py
# PyTorch code for implementing the mixture of softmaxes layer from
# "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model"
# https://arxiv.org/abs/1711.03953
# context holds n_components * nClasses class logits plus n_components prior logits.
context = self.fc(out)

# Non-log version: weight each component softmax by its prior, sum, then take the log.
priors = F.softmax(context[:, -self.n_components:], dim=1)
mixtures = torch.stack([priors[:, i].unsqueeze(1) * F.softmax(context[:, i * self.nClasses : (i + 1) * self.nClasses], dim=1)
                        for i in range(self.n_components)], 1)
out = torch.log(mixtures.sum(1))

# Log-space version (numerically more stable):
# log_priors = F.log_softmax(context[:, -self.n_components:], dim=1)
# log_mixtures = torch.stack([log_priors[:, i].unsqueeze(1) + F.log_softmax(context[:, i * self.nClasses : (i + 1) * self.nClasses], dim=1)
#                             for i in range(self.n_components)], 1)
# out = torch.logsumexp(log_mixtures, dim=1)
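For context, here is a minimal, self-contained sketch of how the snippet above might be wrapped into a standalone module. The class name `MixtureOfSoftmaxes`, the `hidden_size`/`nClasses`/`n_components` parameters, and the example sizes (32, 100, 3) are illustrative assumptions, not part of the original gist.

```python
# Illustrative sketch only; sizes and names are made up, not from the gist.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MixtureOfSoftmaxes(nn.Module):
    def __init__(self, hidden_size, nClasses, n_components):
        super().__init__()
        self.nClasses = nClasses
        self.n_components = n_components
        # One set of class logits per component, plus logits for the mixture priors.
        self.fc = nn.Linear(hidden_size, n_components * nClasses + n_components)

    def forward(self, out):
        # out: (batch, hidden_size) -> log-probabilities of shape (batch, nClasses)
        context = self.fc(out)
        log_priors = F.log_softmax(context[:, -self.n_components:], dim=1)
        log_mixtures = torch.stack(
            [log_priors[:, i].unsqueeze(1)
             + F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
             for i in range(self.n_components)], dim=1)
        # Sum the component probabilities in log space.
        return torch.logsumexp(log_mixtures, dim=1)


if __name__ == "__main__":
    # Example usage with made-up sizes: batch of 4, hidden size 32,
    # a 100-word vocabulary, and 3 mixture components.
    mos = MixtureOfSoftmaxes(hidden_size=32, nClasses=100, n_components=3)
    hidden = torch.randn(4, 32)
    log_probs = mos(hidden)
    print(log_probs.shape)         # torch.Size([4, 100])
    print(log_probs.exp().sum(1))  # each row sums to ~1
```

The output can be fed directly to `nn.NLLLoss`, since the module returns log-probabilities rather than raw logits.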