28 lines
810 B
Python
28 lines
810 B
Python
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross entropy.

    Each element's cross-entropy term is scaled by ``(1 - p) ** gamma``,
    where ``p`` is the probability the model assigns to the target class, so
    confidently-correct elements contribute less than hard ones.

    Args:
        weight: per-class weights ``[weight_1, weight_2, ...]`` with
            ``len(weight) == classes_num``; a larger value makes that class
            more important. ``None`` weights all classes equally.
        reduction: ``'mean'`` (weighted mean, normalized the same way as
            ``nn.CrossEntropyLoss``) or ``'sum'``. ``'none'`` is accepted for
            backward compatibility and yields the plain mean of the
            per-element losses; the returned value is always a scalar.
        gamma: focusing parameter. ``0`` disables the focal term, making the
            result equal to ``nn.CrossEntropyLoss``; when used, the suggested
            range is (0.5, 10.0).
        eps: lower bound applied to ``1 - p`` before exponentiation, which
            keeps the gradient finite when ``p`` reaches 1 and ``gamma`` is
            fractional.
    """

    def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        # Kept as a sub-module so the class weights remain a registered
        # buffer (same state_dict keys, follows .to()/.cuda()) and so that
        # `reduction` / `ignore_index` are stored in one place.
        self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, x, y):
        """Return the scalar focal loss of logits ``x`` against targets ``y``.

        ``x`` and ``y`` are forwarded unchanged to
        ``torch.nn.functional.cross_entropy``, so they must satisfy the same
        requirements as the inputs of ``nn.CrossEntropyLoss``.

        Raises:
            ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
        """
        cross_entropy = nn.functional.cross_entropy
        weight = self.ce.weight
        ignore_index = self.ce.ignore_index
        reduction = self.ce.reduction
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                "{} is not a valid value for reduction".format(reduction))

        # Unweighted, unreduced cross entropy is exactly -log(p_target), so
        # exp(-nll) recovers the target-class probability per element.
        # Reducing or weighting first would make p meaningless.
        nll = cross_entropy(x, y, reduction='none', ignore_index=ignore_index)
        p = torch.exp(-nll)
        modulator = (1 - p).clamp(min=self.eps) ** self.gamma

        # Class weights scale the loss term only, never the probability.
        if weight is None:
            weighted_nll = nll
        else:
            weighted_nll = cross_entropy(
                x, y, weight=weight, reduction='none',
                ignore_index=ignore_index)
        loss = modulator * weighted_nll

        if reduction == 'sum':
            return loss.sum()
        if reduction == 'none' or y.is_floating_point():
            # 'none' historically returned the plain mean; class-probability
            # targets are averaged over all elements by cross entropy too.
            return loss.mean()

        # 'mean' with class-index targets: normalize by the total weight of
        # the non-ignored targets, as nn.CrossEntropyLoss does, so gamma=0
        # reproduces it exactly.
        valid = y != ignore_index
        if weight is None:
            normalizer = valid.sum()
        else:
            safe_y = y.masked_fill(~valid, 0)
            normalizer = (weight[safe_y] * valid.to(weight.dtype)).sum()
        return loss.sum() / normalizer
|
||
|
||
|
||
|
||
|
||
# Script entry point: intentionally a no-op, this module only defines FocalLoss.
if __name__ == "__main__":
    pass