调整学习率调度器的预热迭代次数,更新训练配置中的类别权重,修改损失函数中的权重,优化测试笔记本中的代码结构
This commit is contained in:
parent
294425fce4
commit
125c1ab6bd
@ -3,7 +3,7 @@ scheduler_type: LINEAR_WARMUP_THEN_POLY_SCHEDULER
|
|||||||
# total_iters=epochs * 训练的图像数量
|
# total_iters=epochs * 训练的图像数量
|
||||||
kwargs: |
|
kwargs: |
|
||||||
{
|
{
|
||||||
"warmup_iters": 4356,
|
"warmup_iters": 2904,
|
||||||
"total_iters": 290400,
|
"total_iters": 290400,
|
||||||
"warmup_ratio": 0.000001,
|
"warmup_ratio": 0.000001,
|
||||||
"min_lr": 0.,
|
"min_lr": 0.,
|
||||||
|
@ -8,6 +8,8 @@ epochs: 100
|
|||||||
# 每一类的占比权重,如果要让每一类的占比权重相同,为1.0即可
|
# 每一类的占比权重,如果要让每一类的占比权重相同,为1.0即可
|
||||||
weight:
|
weight:
|
||||||
- 1.0
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
|
||||||
# 数据集存放位置
|
# 数据集存放位置
|
||||||
root:
|
root:
|
||||||
|
@ -13,7 +13,8 @@ class FocalLoss(nn.Module):
|
|||||||
super(FocalLoss, self).__init__()
|
super(FocalLoss, self).__init__()
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)
|
myweight = torch.tensor([1.0, 1.9, 1.0, 1.2]).cuda()
|
||||||
|
self.ce = torch.nn.CrossEntropyLoss(weight=myweight, reduction=reduction)
|
||||||
|
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
logp = self.ce(x, y)
|
logp = self.ce(x, y)
|
||||||
|
14
test.ipynb
14
test.ipynb
@ -2,17 +2,18 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from PIL import Image\n",
|
"from PIL import Image\n",
|
||||||
"import numpy as np"
|
"import numpy as np\n",
|
||||||
|
"import os"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -84,6 +85,13 @@
|
|||||||
"print(image_array.shape)\n",
|
"print(image_array.shape)\n",
|
||||||
"print(image_array)"
|
"print(image_array)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user