SegNeXt/losses.py

28 lines
810 B
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    The modulating factor ``(1 - p_t) ** gamma`` is applied to the
    cross-entropy of **each element** before any reduction. ``p_t`` is the
    predicted probability of the target class, recovered as ``exp(-CE)``
    from the *unweighted* per-element cross-entropy. With ``gamma == 0``
    the result equals ``nn.CrossEntropyLoss(weight=weight, reduction=reduction)``.

    Args:
        weight: Optional per-class weights ``[weight_1, weight_2, ...]`` with
            ``len(weight) == classes_num``. A larger weight makes that class
            count more. Passed straight to ``nn.CrossEntropyLoss``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, with the same meaning
            as in ``nn.CrossEntropyLoss``. For class-index targets, ``'mean'``
            divides by the summed weights of the non-ignored targets.
        gamma: Focusing parameter. ``0`` disables focusing (plain CE).
            When focusing is used, a value in roughly (0.5, 10.0) is suggested.
        eps: Lower bound applied to ``1 - p_t``. It keeps a fractional
            ``gamma`` from producing infinite gradients when ``p_t`` rounds
            to exactly 1.

    Raises:
        ValueError: If ``reduction`` is not one of the supported strings.
    """

    def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(
                f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
            )
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction
        # Per-element weighted CE. The reduction is applied in forward()
        # *after* the focal modulation. Reducing here would make the
        # modulation act on an already-averaged scalar.
        self.ce = nn.CrossEntropyLoss(weight=weight, reduction='none')

    def forward(self, x, y):
        """Return the focal loss of logits ``x`` against targets ``y``.

        ``x`` and ``y`` follow ``nn.CrossEntropyLoss`` input/target
        conventions. The output shape for ``reduction='none'`` is whatever
        ``nn.CrossEntropyLoss(reduction='none')`` returns for these inputs.
        Otherwise the output is a scalar.
        """
        weighted_ce = self.ce(x, y)
        # Unweighted CE is exactly -log p_t, so exp(-ce) is the true target
        # probability even when class weights are configured.
        ce = nn.functional.cross_entropy(
            x, y, reduction='none', ignore_index=self.ce.ignore_index
        )
        p = torch.exp(-ce)
        # The clamp keeps d/dp (1 - p) ** gamma finite for 0 < gamma < 1.
        # eps ** 0 == 1, so gamma == 0 is unaffected.
        loss = (1 - p).clamp(min=self.eps) ** self.gamma * weighted_ce

        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()

        # 'mean': normalize the same way nn.CrossEntropyLoss does.
        if y.is_floating_point():
            # Probability targets are averaged over all elements.
            return loss.mean()
        valid = y != self.ce.ignore_index
        if self.ce.weight is None:
            denom = valid.sum()
        else:
            # Ignored positions may hold an out-of-range index (e.g. -100).
            # Remap them to 0 before indexing; `valid` zeroes their weight.
            safe_y = torch.where(valid, y, torch.zeros_like(y))
            denom = (self.ce.weight[safe_y] * valid).sum()
        return loss.sum() / denom
if __name__ == "__main__":
    # No script entry point is implemented; this module is meant to be imported.
    ...