# Source: https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def one_hot(labels: torch.Tensor,
            num_classes: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            eps: Optional[float] = 1e-6) -> torch.Tensor:
    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.

    Args:
        labels (torch.Tensor): tensor with labels of shape :math:`(N, *)`,
            where N is batch size. Each value is an integer representing
            the correct classification.
        num_classes (int): number of classes in labels.
        device (Optional[torch.device]): the desired device of returned tensor.
            Default: if None, uses the current device for the default tensor
            type (see torch.set_default_tensor_type()). device will be the CPU
            for CPU tensor types and the current CUDA device for CUDA tensor
            types.
        dtype (Optional[torch.dtype]): the desired data type of returned
            tensor. Default: if None, infers data type from values.

    Returns:
        torch.Tensor: the labels in one-hot tensor of shape :math:`(N, C, *)`.

    Examples::
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> kornia.losses.one_hot(labels, num_classes=3)
        tensor([[[[1., 0.],
                  [0., 1.]],
                 [[0., 1.],
                  [0., 0.]],
                 [[0., 0.],
                  [1., 0.]]]])
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not labels.dtype == torch.int64:
        raise ValueError("labels must be of dtype torch.int64. Got: {}"
                         .format(labels.dtype))
    if num_classes < 1:
        raise ValueError("The number of classes must be bigger than zero."
                         " Got: {}".format(num_classes))
    shape = labels.shape
    one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
                          device=device, dtype=dtype)
    # scatter 1.0 into the class channel of each label; eps keeps the
    # downstream log() numerically safe
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps


# based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py

def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}"
            .format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1) + eps

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual focal loss: (1 - p)^gamma down-weights easy examples
    weight = torch.pow(-input_soft + 1., gamma)

    focal = -alpha * weight * torch.log(input_soft)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss


class FocalLoss(nn.Module):
    r"""Criterion that computes Focal loss.

    According to [1], the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:

    - :math:`p_t` is the model's estimated probability for each class.

    Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma \geq 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'none'.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 \leq targets[i] \leq C-1`.

    Examples:
        >>> N = 5  # num_classes
        >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
        >>> loss = kornia.losses.FocalLoss(**kwargs)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alpha: float, gamma: float = 2.0,
                 reduction: str = 'none') -> None:
        super(FocalLoss, self).__init__()
        self.alpha: float = alpha
        self.gamma: float = gamma
        self.reduction: str = reduction
        self.eps: float = 1e-6

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        return focal_loss(input, target, self.alpha, self.gamma,
                          self.reduction, self.eps)
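

# --- Usage sketch (not part of the original gist) ---
# A minimal smoke test of the functions above, assuming a 4-class
# segmentation-style problem with logits of shape (B, C, H, W).
# The shapes and the hyperparameters alpha=0.5, gamma=2.0 are
# illustrative choices, not values taken from the source file.
if __name__ == "__main__":
    num_classes = 4
    logits = torch.randn(2, num_classes, 8, 8, requires_grad=True)  # raw scores (B, C, H, W)
    labels = torch.randint(0, num_classes, (2, 8, 8))                # int64 targets (B, H, W)

    # functional form: reduction='none' keeps a per-pixel loss map
    per_pixel = focal_loss(logits, labels, alpha=0.5, gamma=2.0, reduction='none')
    print(per_pixel.shape)  # expected: torch.Size([2, 8, 8])

    # module form: 'mean' reduction yields a scalar suitable for backprop
    criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
    loss = criterion(logits, labels)
    loss.backward()
    print(loss.item())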