# Gist by @hieuphung97, created October 30, 2020.
# Source: https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


def one_hot(labels: torch.Tensor,
            num_classes: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
where N is batch size. Each value is an integer
representing correct classification.
num_classes (int): number of classes in labels.
device (Optional[torch.device]): the desired device of returned tensor.
Default: if None, uses the current device for the default tensor type
(see torch.set_default_tensor_type()). device will be the CPU for CPU
tensor types and the current CUDA device for CUDA tensor types.
dtype (Optional[torch.dtype]): the desired data type of returned
tensor. Default: if None, infers data type from values.
Returns:
torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
Examples::
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> kornia.losses.one_hot(labels, num_classes=3)
tensor([[[[1., 0.],
[0., 1.]],
[[0., 1.],
[0., 0.]],
[[0., 0.],
[1., 0.]]]]
"""
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not labels.dtype == torch.int64:
        raise ValueError("labels must be of dtype torch.int64. Got: {}"
                         .format(labels.dtype))
    if num_classes < 1:
        raise ValueError("The number of classes must be at least one."
                         " Got: {}".format(num_classes))
    shape = labels.shape
    one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
                          device=device, dtype=dtype)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
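

# A minimal shape check, added for illustration; it is not part of the original
# gist, and `_demo_one_hot` is a hypothetical helper mirroring the docstring
# example above. Note the returned values are offset by `eps`, so they are not
# exactly 0.0 and 1.0.
def _demo_one_hot() -> None:
    demo_labels = torch.LongTensor([[[0, 1], [2, 0]]])  # shape (1, 2, 2)
    encoded = one_hot(demo_labels, num_classes=3)
    # The class axis is inserted at dim=1: (1, 2, 2) -> (1, 3, 2, 2).
    assert encoded.shape == torch.Size([1, 3, 2, 2])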


# based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py
def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))
    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))
    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))
    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be on the same device. Got: {} and {}".format(
                input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1) + eps

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1., gamma)
    focal = -alpha * weight * torch.log(input_soft)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss
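

# A small equivalence sketch, added for illustration (not in the original
# gist); `_demo_focal_loss` is a hypothetical helper. With gamma=0 and
# alpha=1 the modulating factor (1 - p_t)^gamma is 1, so focal loss reduces
# to standard cross-entropy, up to the small eps offsets used above.
def _demo_focal_loss() -> None:
    logits = torch.randn(2, 4, 8, 8)          # assumed (B, C, H, W) logits
    labels = torch.randint(0, 4, (2, 8, 8))   # int64 class targets
    fl = focal_loss(logits, labels, alpha=1.0, gamma=0.0, reduction='mean')
    ce = F.cross_entropy(logits, labels, reduction='mean')
    # The eps terms introduce only a tiny discrepancy.
    assert torch.allclose(fl, ce, atol=1e-3)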


class FocalLoss(nn.Module):
    r"""Criterion that computes Focal loss.

    According to [1], the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:
       - :math:`p_t` is the model's estimated probability for each class.

    Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma \geq 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'none'.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 \leq targets[i] \leq C-1`.

    Examples:
        >>> N = 5  # num_classes
        >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
        >>> loss = kornia.losses.FocalLoss(**kwargs)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alpha: float, gamma: float = 2.0,
                 reduction: str = 'none') -> None:
        super(FocalLoss, self).__init__()
        self.alpha: float = alpha
        self.gamma: float = gamma
        self.reduction: str = reduction
        self.eps: float = 1e-6

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        return focal_loss(input, target, self.alpha,
                          self.gamma, self.reduction, self.eps)
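

# Hedged usage sketch (not part of the original gist): mirrors the docstring
# example but calls FocalLoss directly, so this file runs standalone without
# importing kornia. Shapes and hyperparameters below are illustrative.
if __name__ == "__main__":
    _demo_one_hot()
    _demo_focal_loss()

    num_classes = 5
    criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
    inp = torch.randn(1, num_classes, 3, 5, requires_grad=True)
    tgt = torch.empty(1, 3, 5, dtype=torch.long).random_(num_classes)
    out = criterion(inp, tgt)
    out.backward()  # gradients flow back to `inp`
    print('focal loss:', out.item())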