Skip to content

Instantly share code, notes, and snippets.

@pbamotra
pbamotra / sigsoftmax.py
Created April 10, 2019 04:08
PyTorch implementation of log-sigsoftmax (the logarithm of sigsoftmax) - https://arxiv.org/pdf/1805.10829.pdf
def logsigsoftmax(logits: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Return the logarithm of sigsoftmax applied to ``logits``.

    Sigsoftmax (https://arxiv.org/pdf/1805.10829.pdf) normalizes
    ``exp(z) * sigmoid(z)`` instead of ``exp(z)``::

        sigsoftmax(z)_i = exp(z_i) * sigmoid(z_i) / sum_j exp(z_j) * sigmoid(z_j)

    This function returns ``log(sigsoftmax(z))``, evaluated entirely in log
    space so that very negative logits do not turn into ``-inf`` or ``nan``.

    Args:
        logits: Tensor of unnormalized scores.
        dim: Dimension along which the scores are normalized. Defaults to 1,
            the axis the original implementation always reduced over.

    Returns:
        Tensor with the same shape as ``logits`` holding log-probabilities;
        ``exp`` of the result sums to 1 along ``dim``.
    """
    # log(exp(z) * sigmoid(z)) == z + log(sigmoid(z)).  logsigmoid evaluates
    # the second term directly; log(sigmoid(z)) would underflow to
    # log(0) = -inf once sigmoid(z) rounds to zero.
    log_unnormalized = logits + torch.nn.functional.logsigmoid(logits)

    # logsumexp is the numerically stabilized log of the normalizer, so no
    # manual max-subtraction is needed and an all-underflowed row cannot
    # produce log(0).
    log_normalizer = torch.logsumexp(log_unnormalized, dim=dim, keepdim=True)
    return log_unnormalized - log_normalizer