Skip to content

Instantly share code, notes, and snippets.

@kumar-abhishek
Forked from erogol/NLL_OHEM.py
Created January 29, 2023 21:02
Show Gist options
  • Save kumar-abhishek/355de395d9db92f8db5a2ede02e40a6b to your computer and use it in GitHub Desktop.
Save kumar-abhishek/355de395d9db92f8db5a2ede02e40a6b to your computer and use it in GitHub Desktop.

Revisions

  1. @erogol erogol revised this gist Oct 23, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion NLL_OHEM.py
    Original file line number Diff line number Diff line change
    @@ -17,7 +17,7 @@ def forward(self, x, y, ratio=None):
    x_ = x.clone()
    inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()
    for idx, label in enumerate(y.data):
    inst_losses[idx] = x_.data[idx, label]
    inst_losses[idx] = -x_.data[idx, label]
    #loss_incs = -x_.sum(1)
    _, idxs = inst_losses.topk(num_hns)
    x_hn = x.index_select(0, idxs)
  2. @erogol erogol created this gist Oct 22, 2017.
    25 changes: 25 additions & 0 deletions NLL_OHEM.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,25 @@
    import torch as th


    class NLL_OHEM(th.nn.NLLLoss):
    """ Online hard example mining.
    Needs input from nn.LogSotmax() """

    def __init__(self, ratio):
    super(NLL_OHEM, self).__init__(None, True)
    self.ratio = ratio

    def forward(self, x, y, ratio=None):
    if ratio is not None:
    self.ratio = ratio
    num_inst = x.size(0)
    num_hns = int(self.ratio * num_inst)
    x_ = x.clone()
    inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()
    for idx, label in enumerate(y.data):
    inst_losses[idx] = x_.data[idx, label]
    #loss_incs = -x_.sum(1)
    _, idxs = inst_losses.topk(num_hns)
    x_hn = x.index_select(0, idxs)
    y_hn = y.index_select(0, idxs)
    return th.nn.functional.nll_loss(x_hn, y_hn)