diff --git a/config/train.yaml b/config/train.yaml index 945bdc9..a9e971b 100644 --- a/config/train.yaml +++ b/config/train.yaml @@ -8,8 +8,9 @@ epochs: 100 # 每一类的占比权重,如果要让每一类的占比权重相同,为1.0即可 weight: - 1.0 + - 1.9 - 1.0 - - 1.0 + - 1.2 # 数据集存放位置 root: diff --git a/losses.py b/losses.py index c9d5ca6..717b3ff 100644 --- a/losses.py +++ b/losses.py @@ -13,8 +13,7 @@ class FocalLoss(nn.Module): super(FocalLoss, self).__init__() self.gamma = gamma self.eps = eps - myweight = torch.tensor([1.0, 1.9, 1.0, 1.2]).cuda() - self.ce = torch.nn.CrossEntropyLoss(weight=myweight, reduction=reduction) + self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction) def forward(self, x, y): logp = self.ce(x, y)