SegNeXt/losses.py

28 lines
810 B
Python
Raw Normal View History

2023-04-07 22:29:25 +08:00
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    For every element with target class t and predicted probability p_t:

        FL = -w_t * (1 - p_t) ** gamma * log(p_t)

    The modulating factor is applied per element, before reduction, so each
    well-classified element (p_t close to 1) is down-weighted individually.

    Args:
        weight: optional per-class weights, one entry per class
            ([weight_1, weight_2, ...], len(weight) == number of classes).
            A larger weight makes the corresponding class more important.
            May be a tensor or any sequence accepted by ``torch.tensor``.
        reduction: 'mean' | 'sum' | 'none', with the same meaning as in
            ``torch.nn.CrossEntropyLoss``. 'mean' divides by the summed target
            weights (the number of non-ignored elements when ``weight`` is
            None), exactly like weighted cross-entropy.
        gamma: focusing parameter. 0 disables focusing (plain cross-entropy);
            when used, values in roughly (0.5, 10.0) are suggested.
        eps: lower bound applied to (1 - p_t) before raising it to ``gamma``,
            which keeps the gradient finite for 0 < gamma < 1 when p_t == 1.

    Raises:
        ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
    """

    def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"{reduction} is not a valid value for reduction")
        # CrossEntropyLoss only accepts a Tensor (or None) as its weight buffer.
        if weight is not None and not isinstance(weight, torch.Tensor):
            weight = torch.tensor(weight, dtype=torch.get_default_dtype())
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction
        # Kept as a submodule so `weight` stays a registered buffer (follows
        # .to(device), same state_dict key as before) and so its ignore_index
        # is the single source of truth for which targets are skipped.
        self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, x, y):
        """Return the focal loss of logits ``x`` against class indices ``y``.

        Args:
            x: unnormalized logits, shape (N, C) or (N, C, d1, d2, ...).
            y: integer class indices, shape (N,) or (N, d1, d2, ...); entries
                equal to ``self.ce.ignore_index`` contribute nothing.

        Returns:
            A scalar for 'mean' / 'sum', otherwise a tensor shaped like ``y``.
        """
        ignore_index = self.ce.ignore_index
        # Unweighted per-element CE, i.e. -log(p_t) (0 at ignored targets).
        # Class weights are applied afterwards: folding them into the CE first
        # would make exp(-ce) equal p_t ** w_t instead of p_t.
        ce = nn.functional.cross_entropy(
            x, y, reduction='none', ignore_index=ignore_index
        )
        pt = torch.exp(-ce)
        loss = (1 - pt).clamp(min=self.eps) ** self.gamma * ce

        valid = y != ignore_index
        if self.ce.weight is None:
            w = valid.to(loss.dtype)
        else:
            # Ignored targets may be out of range as an index (e.g. -100):
            # map them to class 0 for the lookup, then zero their weight.
            w = self.ce.weight[y.masked_fill(~valid, 0)] * valid
        loss = loss * w

        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        # Weighted mean, normalised the same way as nn.CrossEntropyLoss.
        return loss.sum() / w.sum()
if __name__ == "__main__":
    # Intentionally a no-op when this file is executed as a script.
    ...