@benwu232
Created August 25, 2017 06:50
gistfile1.txt
import numpy as np
import torch
from torch.autograd import Variable


def one_hot(size, index):
    """Creates a matrix of one-hot vectors.
    ```
    import torch
    import torch_extras
    setattr(torch, 'one_hot', torch_extras.one_hot)
    size = (3, 3)
    index = torch.LongTensor([2, 0, 1]).view(-1, 1)
    torch.one_hot(size, index)
    # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
    ```
    """
    # Start from an all-zero LongTensor of the requested shape.
    y_onehot = torch.LongTensor(*size).fill_(0)
    y_onehot = Variable(y_onehot, volatile=index.volatile)
    ones = Variable(torch.LongTensor(index.size()).fill_(1))
    # Scatter a 1 into each row at the column given by `index`.
    y_onehot = y_onehot.scatter_(1, index.view(-1, 1), ones.view(-1, 1))
    return y_onehot
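# Minimal usage sketch (my addition, not part of the original gist): `index`
# must be a Variable here, since the function reads `index.volatile`.
#   idx = Variable(torch.LongTensor([2, 0, 1]).view(-1, 1))
#   one_hot((3, 3), idx)
#   # -> rows [[0, 0, 1], [1, 0, 0], [0, 1, 0]]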

# weight_matrix is an N x N matrix giving the penalty weight for each
# (target class, predicted class) pair.
class WeightMatrixLoss(torch.nn.Module):
    def __init__(self, weight_matrix=None):
        super().__init__()
        #self.register_buffer('weight_matrix', weight_matrix)
        self.weight_matrix = weight_matrix

    def forward(self, p_onehot, target):
        batch_size = len(target)

        # Build the one-hot target matrix on the CPU, then move it to the GPU.
        target = target.cpu()
        t_onehot = one_hot(p_onehot.size(), target)
        t = t_onehot.unsqueeze(1).cuda()

        #p_onehot = p_onehot.cpu()
        p = p_onehot.unsqueeze(2)

        # Per-sample cross entropy via a batched matmul of the one-hot targets
        # with the (log-)probabilities.
        ce = -torch.bmm(t.float(), p)
        #ce = torch.squeeze(ce, 1)
        ce = ce.view((1, -1))

        # Look up the weight for each (target, predicted) class pair.
        _, predict_value = torch.max(p_onehot.data, 1)
        weight_line = np.zeros(batch_size, dtype=np.float32)
        #weight_matrix = self.weight_matrix.numpy()
        np_t = target.data.numpy()
        np_p = predict_value.cpu().view(-1).numpy()
        for k in range(batch_size):
            weight_line[k] = self.weight_matrix[np_t[k]][np_p[k]]
        weight_line = Variable(torch.from_numpy(weight_line).view((-1, 1))).cuda()

        # Weighted mean cross entropy over the batch.
        wce = torch.mm(ce, weight_line).view(-1)
        return wce / batch_size
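

# Usage sketch (my addition, not from the original gist): it assumes a CUDA
# device and the legacy pre-0.4 Variable API used above, and treats `p_onehot`
# as a (batch, num_classes) matrix of log-probabilities.
if __name__ == '__main__':
    num_classes = 3
    # Hypothetical penalty matrix; real values depend on the application.
    weight_matrix = np.ones((num_classes, num_classes), dtype=np.float32)
    criterion = WeightMatrixLoss(weight_matrix)
    logits = Variable(torch.randn(4, num_classes))
    log_probs = torch.nn.functional.log_softmax(logits, dim=1).cuda()
    targets = Variable(torch.LongTensor([0, 2, 1, 0])).cuda()
    print(criterion(log_probs, targets))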