-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathfocalLoss.py
105 lines (79 loc) · 3.48 KB
/
focalLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# following:
# https://github.com/kornia/kornia/
# which is based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py
def one_hot(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Convert an integer label tensor to a (soft) one-hot encoding.

    Args:
        labels: int64 tensor of shape ``(N, *)`` with values in ``[0, num_classes)``.
        num_classes: number of classes ``C``; the output gains a class axis of this size.
        device: device of the returned tensor (default: current default device).
        dtype: dtype of the returned tensor (default: current default dtype).
        eps: small constant added to *every* entry so downstream logarithm-style
            operations never see an exact zero.

    Returns:
        Tensor of shape ``(N, C, *)`` holding ``1 + eps`` at each label index
        and ``eps`` everywhere else.

    Raises:
        TypeError: if ``labels`` is not a ``torch.Tensor``.
        ValueError: if ``labels`` is not int64, or ``num_classes < 1``.
    """
    if not isinstance(labels, torch.Tensor):
        raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
    if not labels.dtype == torch.int64:
        raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")
    if num_classes < 1:
        # Message now agrees with the check: any positive class count is accepted
        # (the old text claimed "bigger than one" while the guard allowed 1).
        raise ValueError(f"The number of classes must be a positive integer. Got: {num_classes}")
    shape = labels.shape
    one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
    # scatter_ writes 1.0 along the class axis at each label index; eps shifts all entries.
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
def focal_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    gamma: float = 2.0,
    reduction: str = 'none',
    eps: Optional[float] = None,
) -> torch.Tensor:
    """Multi-class focal loss: ``-alpha * (1 - p)^gamma * log(p)`` per sample.

    Args:
        input: raw logits of shape ``(B, C, *)``.
        target: int64 class indices of shape ``(B, *)``, on the same device as ``input``.
        alpha: scalar weighting factor applied to every term.
        gamma: focusing exponent; larger values down-weight well-classified examples.
        reduction: ``'none'`` (per-sample loss of shape ``(B, *)``), ``'mean'`` or ``'sum'``.
        eps: deprecated and unused; passing a value only triggers a DeprecationWarning.

    Returns:
        The loss: shape ``(B, *)`` for ``reduction='none'``, otherwise a scalar tensor.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on shape or device mismatches between ``input`` and ``target``.
        NotImplementedError: for an unknown ``reduction`` mode.
    """
    if eps is not None and not torch.jit.is_scripting():
        warnings.warn(
            "`focal_loss` has been reworked for improved numerical stability "
            "and the `eps` argument is no longer necessary",
            DeprecationWarning,
            stacklevel=2,
        )
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not len(input.shape) >= 2:
        raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input.shape}")
    if input.size(0) != target.size(0):
        raise ValueError(f'Expected input batch_size ({input.size(0)}) to match target batch_size ({target.size(0)}).')
    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError(f'Expected target size {out_size}, got {target.size()}')
    if not input.device == target.device:
        raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")
    # Probabilities and log-probabilities over the class axis. log_softmax is
    # computed directly (not as log(softmax)) for numerical stability.
    input_soft: torch.Tensor = F.softmax(input, dim=1)
    log_input_soft: torch.Tensor = F.log_softmax(input, dim=1)
    # One-hot targets; `one_hot` adds a small eps offset to every entry.
    target_one_hot: torch.Tensor = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)
    # Focal term, selected per element by the one-hot mask via einsum. Operands
    # are passed positionally: the legacy tuple-of-operands einsum call
    # `einsum(eq, (a, b))` is deprecated in favor of `einsum(eq, a, b)`.
    weight = torch.pow(1.0 - input_soft, gamma)
    focal = -alpha * weight * log_input_soft
    loss_tmp = torch.einsum('bc...,bc...->b...', target_one_hot, focal)
    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
    return loss
class FocalLoss(nn.Module):
    """Criterion wrapper around the functional :func:`focal_loss`.

    Stores the loss hyper-parameters once at construction so the module can be
    used like any other ``nn`` criterion, e.g. ``FocalLoss(alpha=0.5)(logits, labels)``.
    """

    def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = 'none', eps: Optional[float] = None) -> None:
        super().__init__()
        # Hyper-parameters are forwarded verbatim to focal_loss on every call.
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        # eps is deprecated in the functional form; kept only for compatibility.
        self.eps = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss of ``input`` logits against ``target`` indices."""
        return focal_loss(input, target, self.alpha, self.gamma, self.reduction, self.eps)