添加 .gitignore 文件,更新训练和预测配置,修改模型类别,删除冗余的 readme 文件,增加测试脚本,优化学习率调度器参数
This commit is contained in:
parent
85379804a2
commit
4799da54ce
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
*.jpg
|
||||||
|
*.png
|
||||||
|
__pycache__
|
@ -3,8 +3,8 @@ scheduler_type: LINEAR_WARMUP_THEN_POLY_SCHEDULER
|
|||||||
# total_iters=epochs * 训练的图像数量
|
# total_iters=epochs * 训练的图像数量
|
||||||
kwargs: |
|
kwargs: |
|
||||||
{
|
{
|
||||||
"warmup_iters": 62,
|
"warmup_iters": 4356,
|
||||||
"total_iters": 620,
|
"total_iters": 290400,
|
||||||
"warmup_ratio": 0.000001,
|
"warmup_ratio": 0.000001,
|
||||||
"min_lr": 0.,
|
"min_lr": 0.,
|
||||||
"power": 1.
|
"power": 1.
|
||||||
|
@ -40,4 +40,6 @@ nmf2d_config:
|
|||||||
# 类别
|
# 类别
|
||||||
classes:
|
classes:
|
||||||
- background # 必须要
|
- background # 必须要
|
||||||
- leaf
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
@ -3,7 +3,7 @@ device: "cuda"
|
|||||||
# -1表示从不加载任何权重就进行预测
|
# -1表示从不加载任何权重就进行预测
|
||||||
# 0表示使用官方提供的权重进行预测
|
# 0表示使用官方提供的权重进行预测
|
||||||
# 1表示使用自己的权重进行预测
|
# 1表示使用自己的权重进行预测
|
||||||
mode: 0
|
mode: -1
|
||||||
checkpoint:
|
checkpoint:
|
||||||
- pretrained # 目录
|
- pretrained # 目录
|
||||||
- segnext_tiny_512x512_ade_160k.pth # 文件名
|
- segnext_tiny_512x512_ade_160k.pth # 文件名
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
device: cuda
|
device: cuda
|
||||||
batch_size: 2
|
batch_size: 16
|
||||||
image_height: 512
|
image_height: 200
|
||||||
image_width: 512
|
image_width: 200
|
||||||
workers: 0
|
workers: 0
|
||||||
epochs: 10
|
epochs: 100
|
||||||
|
|
||||||
# 每一类的占比权重,如果要让每一类的占比权重相同,为1.0即可
|
# 每一类的占比权重,如果要让每一类的占比权重相同,为1.0即可
|
||||||
weight:
|
weight:
|
||||||
@ -30,7 +30,7 @@ save_path:
|
|||||||
# -1表示从零开始训练网络,即不加载任何权重
|
# -1表示从零开始训练网络,即不加载任何权重
|
||||||
# 0表示使用官方提供的权重
|
# 0表示使用官方提供的权重
|
||||||
# 1表示使用自己的权重
|
# 1表示使用自己的权重
|
||||||
mode: 0
|
mode: -1
|
||||||
checkpoint:
|
checkpoint:
|
||||||
- pretrained # 目录
|
- pretrained # 目录
|
||||||
- segnext_tiny_512x512_ade_160k.pth # 文件名
|
- segnext_tiny_512x512_ade_160k.pth # 文件名
|
||||||
|
@ -1 +0,0 @@
|
|||||||
训练图像
|
|
@ -1 +0,0 @@
|
|||||||
标签
|
|
@ -1 +0,0 @@
|
|||||||
验证图像数据
|
|
@ -1 +0,0 @@
|
|||||||
标签数据
|
|
@ -187,7 +187,7 @@ class LinearWarmupThenPolyScheduler:
|
|||||||
total_iters: 总训练步数
|
total_iters: 总训练步数
|
||||||
min_lr: 最低学习率
|
min_lr: 最低学习率
|
||||||
"""
|
"""
|
||||||
def __init__(self, optimizer, warmup_iters=1500, total_iters=2000, warmup_ratio=1e-6, min_lr=0., power=1.):
|
def __init__(self, optimizer, warmup_iters=1500, total_iters=20000, warmup_ratio=1e-6, min_lr=0., power=1.):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.current_iters = 0
|
self.current_iters = 0
|
||||||
self.warmup_iters = warmup_iters
|
self.warmup_iters = warmup_iters
|
||||||
|
@ -1 +0,0 @@
|
|||||||
权重, 预训练权重文件请移步官方实现的仓库下载
|
|
110
test.ipynb
Normal file
110
test.ipynb
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import numpy as np"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"(200, 200, 3)\n",
|
||||||
|
"[[[226 226 226]\n",
|
||||||
|
" [227 227 227]\n",
|
||||||
|
" [225 225 225]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [153 153 153]\n",
|
||||||
|
" [153 153 153]\n",
|
||||||
|
" [152 152 152]]\n",
|
||||||
|
"\n",
|
||||||
|
" [[227 227 227]\n",
|
||||||
|
" [222 222 222]\n",
|
||||||
|
" [221 221 221]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [151 151 151]\n",
|
||||||
|
" [150 150 150]\n",
|
||||||
|
" [149 149 149]]\n",
|
||||||
|
"\n",
|
||||||
|
" [[227 227 227]\n",
|
||||||
|
" [226 226 226]\n",
|
||||||
|
" [221 221 221]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [153 153 153]\n",
|
||||||
|
" [152 152 152]\n",
|
||||||
|
" [153 153 153]]\n",
|
||||||
|
"\n",
|
||||||
|
" ...\n",
|
||||||
|
"\n",
|
||||||
|
" [[234 234 234]\n",
|
||||||
|
" [229 229 229]\n",
|
||||||
|
" [229 229 229]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [160 160 160]\n",
|
||||||
|
" [159 159 159]\n",
|
||||||
|
" [159 159 159]]\n",
|
||||||
|
"\n",
|
||||||
|
" [[231 231 231]\n",
|
||||||
|
" [233 233 233]\n",
|
||||||
|
" [230 230 230]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [159 159 159]\n",
|
||||||
|
" [160 160 160]\n",
|
||||||
|
" [159 159 159]]\n",
|
||||||
|
"\n",
|
||||||
|
" [[230 230 230]\n",
|
||||||
|
" [236 236 236]\n",
|
||||||
|
" [234 234 234]\n",
|
||||||
|
" ...\n",
|
||||||
|
" [155 155 155]\n",
|
||||||
|
" [156 156 156]\n",
|
||||||
|
" [154 154 154]]]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# 读取图片\n",
|
||||||
|
"image_path = 'dataset/train/images/000229.jpg'\n",
|
||||||
|
"image = Image.open(image_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# 将图片转换为 NumPy 数组\n",
|
||||||
|
"image_array = np.array(image)\n",
|
||||||
|
"\n",
|
||||||
|
"# 显示图片的像素值\n",
|
||||||
|
"print(image_array.shape)\n",
|
||||||
|
"print(image_array)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
8
train.py
8
train.py
@ -260,7 +260,8 @@ def train(
|
|||||||
f=f"{os.path.sep.join(train_config['save_path'])}_val_Iou{100 * best_val_model['avg_iou']:.3f}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}.pth"
|
f=f"{os.path.sep.join(train_config['save_path'])}_val_Iou{100 * best_val_model['avg_iou']:.3f}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}.pth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -273,6 +274,11 @@ if __name__ == "__main__":
|
|||||||
net=net,
|
net=net,
|
||||||
optimizer=optimizer
|
optimizer=optimizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 计算并打印网络的参数量
|
||||||
|
# num_params = count_parameters(net)
|
||||||
|
# print(f"网络的参数量: {num_params}")
|
||||||
|
|
||||||
train(
|
train(
|
||||||
net=net,
|
net=net,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user