weighted_cross_entropy.py
Gist: nasimrahaman/a5fb23f096d7b0c3880e1622938d0901
@nasimrahaman
Last active November 16, 2023 04:54

Revisions

  1. nasimrahaman revised this gist Jul 8, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion in weighted_cross_entropy.py
    (The revision drops the `device_id` keyword from the `.cuda()` call in favor of
    a positional device argument; later PyTorch releases call this argument `device`.)

    @@ -21,7 +21,7 @@ def class_select(logits, target):
             one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                                    .long()
                                                    .repeat(batch_size, 1)
    -                                               .cuda(device_id=device)
    +                                               .cuda(device)
                                                    .eq(target.data.repeat(num_classes, 1).t()))
         else:
             one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
  2. nasimrahaman created this gist Jul 8, 2017.
    64 changes: 64 additions & 0 deletions in weighted_cross_entropy.py (the full file):

    import torch
    import torch.nn as nn


    def log_sum_exp(x):
        # Numerically stable log-sum-exp. See implementation detail in
        # http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
        # b is a shift factor (the row-wise max); see link.
        # x.size() = [N, C]:
        b, _ = torch.max(x, 1)
        y = b + torch.log(torch.exp(x - b.expand_as(x)).sum(1))
        # y.size() = [N, 1]. Squeeze to [N] and return
        return y.squeeze(1)
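
    # For intuition: with x = [[1000., 1000.]], exp(1000) overflows to inf, so the
    # naive torch.log(torch.exp(x).sum(1)) returns inf, while the shifted form
    # above returns 1000 + log(2), approximately 1000.6931.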


    def class_select(logits, target):
        # In numpy, this would be logits[np.arange(batch_size), target].
        batch_size, num_classes = logits.size()
        if target.is_cuda:
            device = target.data.get_device()
            one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                                   .long()
                                                   .repeat(batch_size, 1)
                                                   .cuda(device_id=device)
                                                   .eq(target.data.repeat(num_classes, 1).t()))
        else:
            one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
                                                   .long()
                                                   .repeat(batch_size, 1)
                                                   .eq(target.data.repeat(num_classes, 1).t()))
        return logits.masked_select(one_hot_mask)
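
    # For example, with num_classes = 3 and target = [2, 0], the mask is
    # [[0, 0, 1], [1, 0, 0]], so masked_select picks logits[0, 2] and logits[1, 0].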


    def cross_entropy_with_weights(logits, target, weights=None):
        assert logits.dim() == 2
        assert not target.requires_grad
        target = target.squeeze(1) if target.dim() == 2 else target
        assert target.dim() == 1
        loss = log_sum_exp(logits) - class_select(logits, target)
        if weights is not None:
            # loss.size() = [N]. Assert weights has the same shape.
            assert list(loss.size()) == list(weights.size())
            # Weight the loss
            loss = loss * weights
        return loss
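
    # Note: per sample i, the unweighted loss above is
    #     log(sum_j exp(logits[i, j])) - logits[i, target[i]],
    # i.e. the negative log of the softmax probability of the target class.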


    class CrossEntropyLoss(nn.Module):
        """
        Cross entropy with instance-wise weights. Leave `aggregate` as None to
        obtain a loss vector of shape (batch_size,).
        """
        def __init__(self, aggregate='mean'):
            super(CrossEntropyLoss, self).__init__()
            assert aggregate in ['sum', 'mean', None]
            self.aggregate = aggregate

        def forward(self, input, target, weights=None):
            if self.aggregate == 'sum':
                return cross_entropy_with_weights(input, target, weights).sum()
            elif self.aggregate == 'mean':
                return cross_entropy_with_weights(input, target, weights).mean()
            elif self.aggregate is None:
                return cross_entropy_with_weights(input, target, weights)
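
A minimal usage sketch (not part of the gist; the tensors below are made up). It
assumes a PyTorch 0.x-era environment, where reductions like `torch.max(x, 1)` and
`.sum(1)` keep the reduced dimension as `log_sum_exp` above expects, and where
tensors are wrapped in `torch.autograd.Variable`:

    import torch
    from torch.autograd import Variable

    # Hypothetical batch: 4 samples, 3 classes.
    logits = Variable(torch.randn(4, 3), requires_grad=True)
    target = Variable(torch.LongTensor([0, 2, 1, 0]))
    weights = Variable(torch.Tensor([1.0, 0.5, 2.0, 1.0]))  # per-instance weights

    criterion = CrossEntropyLoss(aggregate='mean')
    loss = criterion(logits, target, weights)  # scalar: mean of the weighted losses
    loss.backward()

On current PyTorch releases, the same per-instance weighting can be obtained from
the built-in loss by skipping its reduction, e.g.
(torch.nn.functional.cross_entropy(logits, target, reduction='none') * weights).mean().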