Revisions
nasimrahaman revised this gist on Jul 8, 2017. 1 changed file with 1 addition and 1 deletion.
@@ -21,7 +21,7 @@ def class_select(logits, target):
         one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                                .long()
                                                .repeat(batch_size, 1)
-                                               .cuda(device_id=device)
+                                               .cuda(device)
                                                .eq(target.data.repeat(num_classes, 1).t()))
     else:
         one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
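The one-line change drops the `device_id` keyword from `.cuda(...)`, presumably because that argument was renamed in a newer PyTorch release. For comparison, in PyTorch >= 0.4 the whole one-hot-mask construction becomes unnecessary; a minimal sketch of the equivalent selection (the name `class_select_modern` is hypothetical, not from the gist):

import torch

def class_select_modern(logits, target):
    # logits: [N, C] float tensor; target: [N] long tensor on the same device.
    # Picks logits[i, target[i]] for every row i; equivalent to the gist's
    # one-hot mask + masked_select, without Variable or device juggling.
    return logits.gather(1, target.unsqueeze(1)).squeeze(1)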
nasimrahaman created this gist on Jul 8, 2017.
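The file below computes the per-sample loss as log_sum_exp(logits) minus the selected target logit, which is the usual rewriting of cross entropy. Made explicit (the gist itself only links the exp-normalize trick): for a logit row $x \in \mathbb{R}^C$ and target class $t$,

$$\mathrm{CE}(x, t) = -\log \frac{e^{x_t}}{\sum_{c} e^{x_c}} = \operatorname{logsumexp}(x) - x_t, \qquad \operatorname{logsumexp}(x) = b + \log \sum_{c} e^{x_c - b}, \quad b = \max_c x_c,$$

where subtracting the maximum $b$ before exponentiating keeps the exponentials from overflowing.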
import torch
import torch.nn as nn


def log_sum_exp(x):
    # See implementation detail in
    # http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # b is a shift factor. See link.
    # x.size() = [N, C]:
    b, _ = torch.max(x, 1)
    y = b + torch.log(torch.exp(x - b.expand_as(x)).sum(1))
    # y.size() = [N, 1]. Squeeze to [N] and return
    return y.squeeze(1)


def class_select(logits, target):
    # In numpy, this would be logits[np.arange(batch_size), target].
    batch_size, num_classes = logits.size()
    if target.is_cuda:
        device = target.data.get_device()
        one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                               .long()
                                               .repeat(batch_size, 1)
                                               .cuda(device_id=device)
                                               .eq(target.data.repeat(num_classes, 1).t()))
    else:
        one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                               .long()
                                               .repeat(batch_size, 1)
                                               .eq(target.data.repeat(num_classes, 1).t()))
    return logits.masked_select(one_hot_mask)


def cross_entropy_with_weights(logits, target, weights=None):
    assert logits.dim() == 2
    assert not target.requires_grad
    target = target.squeeze(1) if target.dim() == 2 else target
    assert target.dim() == 1
    loss = log_sum_exp(logits) - class_select(logits, target)
    if weights is not None:
        # loss.size() = [N]. Assert weights has the same shape.
        assert list(loss.size()) == list(weights.size())
        # Weight the loss
        loss = loss * weights
    return loss


class CrossEntropyLoss(nn.Module):
    """
    Cross entropy with instance-wise weights. Leave `aggregate` as None to obtain
    a loss vector of shape (batch_size,).
    """
    def __init__(self, aggregate='mean'):
        super(CrossEntropyLoss, self).__init__()
        assert aggregate in ['sum', 'mean', None]
        self.aggregate = aggregate

    def forward(self, input, target, weights=None):
        if self.aggregate == 'sum':
            return cross_entropy_with_weights(input, target, weights).sum()
        elif self.aggregate == 'mean':
            return cross_entropy_with_weights(input, target, weights).mean()
        elif self.aggregate is None:
            return cross_entropy_with_weights(input, target, weights)
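The gist is written against the PyTorch 0.x API (`Variable`, reductions that keep the reduced dimension). A minimal usage sketch under that assumption; the tensors below are made up:

import torch
from torch.autograd import Variable

# Hypothetical batch: three samples, four classes.
logits = Variable(torch.randn(3, 4))
target = Variable(torch.LongTensor([0, 3, 1]))
weights = Variable(torch.FloatTensor([1.0, 0.5, 2.0]))  # per-instance weights

# Scalar loss: weighted per-sample cross entropy, then averaged.
criterion = CrossEntropyLoss(aggregate='mean')
loss = criterion(logits, target, weights)

# aggregate=None returns the unreduced loss vector of shape (3,).
loss_vector = CrossEntropyLoss(aggregate=None)(logits, target, weights)

On newer PyTorch the same effect is available out of the box: torch.nn.functional.cross_entropy(logits, target, reduction='none') gives the per-sample loss vector, which can then be multiplied by the instance weights and reduced.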