2023-04-07 22:29:25 +08:00
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
weight: 每一种类别的权重,越大,说明该类别越重要
|
|
|
|
|
[weight_1, weight_2, ...]
|
|
|
|
|
len(weight) = classes_num
|
|
|
|
|
gamma: 为0表示关闭该参数的影响,如果需要使用,范围应为(0.5, 10.0)
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
|
|
|
|
|
super(FocalLoss, self).__init__()
|
|
|
|
|
self.gamma = gamma
|
|
|
|
|
self.eps = eps
|
2024-10-15 19:51:05 +08:00
|
|
|
|
self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)
|
2023-04-07 22:29:25 +08:00
|
|
|
|
|
|
|
|
|
def forward(self, x, y):
|
|
|
|
|
logp = self.ce(x, y)
|
|
|
|
|
p = torch.exp(-logp)
|
|
|
|
|
loss = (1 - p) ** self.gamma * logp
|
|
|
|
|
return loss.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
pass
|