st-gumbel.py: a straight-through Gumbel-Softmax implementation in PyTorch
@vyraun, forked from yzh119/st-gumbel.py (original gist created by @yzh119 on January 12, 2018; fork created February 14, 2020)
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def sample_gumbel(shape, eps=1e-20):
    """Sample Gumbel(0, 1) noise via the inverse CDF: -log(-log(U))."""
    U = torch.rand(shape).cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))

def gumbel_softmax_sample(logits, temperature):
    """Draw a soft sample from the Gumbel-Softmax (Concrete) distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature):
    """Straight-through Gumbel-Softmax.

    input: [*, n_class]
    return: [*, n_class], a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass emits the hard one-hot
    # y_hard, while the backward pass differentiates through the soft sample y.
    return (y_hard - y).detach() + y
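
# A device-agnostic variant (a sketch, not part of the original gist): the
# helpers above hard-code .cuda(), so they fail on CPU-only machines. Newer
# PyTorch (>= 0.4) provides torch.rand_like, which draws noise on the same
# device as the logits; `sample_gumbel_like` is a hypothetical name.
def sample_gumbel_like(logits, eps=1e-20):
    """Sample Gumbel(0, 1) noise matching `logits` in shape and device."""
    U = torch.rand_like(logits)
    return -torch.log(-torch.log(U + eps) + eps)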

if __name__ == '__main__':
    import math

    # Empirical check: over 20000 samples, the per-class counts of the hard
    # one-hot outputs should be roughly proportional to [0.1, 0.4, 0.3, 0.2].
    logits = Variable(torch.cuda.FloatTensor(
        [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000))
    print(gumbel_softmax(logits, 0.8).sum(dim=0))
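
    # Straight-through gradient check (a sketch with made-up values): the
    # forward pass is a hard one-hot, yet gradients still reach the logits
    # through the soft sample kept by (y_hard - y).detach() + y.
    values = Variable(torch.cuda.FloatTensor([[0.0, 1.0, 2.0, 3.0]]))
    soft_logits = Variable(
        torch.cuda.FloatTensor([[0.1, 0.2, 0.3, 0.4]]), requires_grad=True)
    loss = (gumbel_softmax(soft_logits, 0.8) * values).sum()
    loss.backward()
    print(soft_logits.grad)  # non-zero despite the discrete forward output

    # Recent PyTorch versions (>= 1.0, an assumption worth verifying) ship an
    # equivalent built-in; hard=True applies the same straight-through trick:
    # print(F.gumbel_softmax(soft_logits, tau=0.8, hard=True))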