# SegNeXt/train.py
import copy
import os.path
import numpy as np
import torch
import yaml
from pathlib import Path
from tqdm import tqdm
import utils
import data_utils
import model_utils
from torch.utils.data import DataLoader
import losses
from datetime import datetime
"""
1 epoch train
@:param epochs: 总共的epoch数
@:param epoch: 当前epoch
@:param net: 神经网络模型
@:param train_data_loader: 训练数据加载器
@:param image_size: 图片大小
@:param classes_num: 类别数
@:param loss_fn: 损失函数
@:param lr_scheduler: 学习率调度器
@:param optimizer: 优化器
@:param device: 运行场地
@:return 1 epoch train avg loss, 1 epoch train avg scores
"""
def fit(
epochs,
epoch,
net,
train_data_loader,
image_size,
classes_num,
loss_fn,
lr_scheduler,
optimizer,
device="cuda"
):
    matrix = data_utils.ConfusionMatrix(classes_num)
    scores_list = []
    loss_list = []
    progress_bar = tqdm(train_data_loader)
    for idx, data in enumerate(progress_bar):
        images, labels = data
        # Move the batch to the target device (assumes the loader yields CPU tensors).
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # The model is assumed to output flattened per-pixel logits of shape
        # (N, H*W, C); reshape them to (N, C, H, W) for the loss and the metrics.
        predictions = torch.transpose(net(images), -2, -1).view(-1, classes_num, *image_size)
        # Per-batch metrics: update, read, then reset the confusion matrix.
        matrix.update(labels, data_utils.inv_one_hot_of_outputs(predictions, device), device)
        scores = matrix.get_scores()
        matrix.reset()
        scores_list.append(scores)
        loss = loss_fn(
            predictions,
            torch.squeeze(labels, dim=1).to(dtype=torch.long)
        )
        loss_value = loss.item()
        # Guard the running average against NaN batches by substituting the
        # worst loss seen so far (or 1.0 on the first batch).
        if np.isnan(loss_value):
            loss_value = max(loss_list) if len(loss_list) != 0 else 1.0
        loss_list.append(loss_value)
        loss.backward()
        optimizer.step()
        # Step the scheduler per batch, after the optimizer step, as PyTorch expects.
        lr_scheduler.step()
        progress_bar.set_description(
            f"train --> Epoch {epoch + 1} / {epochs}, batch_loss: {loss_value:.3f}, "
            f"batch_iou: {scores['avg_iou']:.3f}, batch_accuracy: {scores['accuracy']:.3f}"
        )
    progress_bar.close()
    return sum(loss_list) / len(loss_list), utils.avg_confusion_matrix_scores_list(scores_list)
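

# Shape walkthrough for the prediction reshaping used in fit() and val(); the
# (N, H*W, C) output layout is an assumption inferred from the transpose/view
# calls, not confirmed elsewhere in this file:
#   net(images)                    -> (N, H*W, C)  flattened per-pixel logits
#   torch.transpose(x, -2, -1)     -> (N, C, H*W)
#   .view(-1, classes_num, H, W)   -> (N, C, H, W) layout expected by the loss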
"""
1 epoch train
@:param epochs: 总共的epoch数
@:param epoch: 当前epoch
@:param net: 神经网络模型
@:param train_data_loader: 验证数据加载器
@:param image_size: 图片大小
@:param classes_num: 类别数
@:param loss_fn: 损失函数
@:param device: 运行场地
@:return val avg loss, val avg scores
"""
@torch.no_grad()
def val(
epochs,
epoch,
net,
val_data_loader,
image_size,
classes_num,
loss_fn,
device="cuda"
):
    matrix = data_utils.ConfusionMatrix(classes_num)
    scores_list = []
    loss_list = []
    progress_bar = tqdm(val_data_loader)
    for idx, data in enumerate(progress_bar):
        images, labels = data
        # Move the batch to the target device (assumes the loader yields CPU tensors).
        images, labels = images.to(device), labels.to(device)
        predictions = torch.transpose(net(images), -2, -1).view(-1, classes_num, *image_size)
        # Per-batch metrics: update, read, then reset the confusion matrix.
        matrix.update(labels, data_utils.inv_one_hot_of_outputs(predictions, device), device)
        scores = matrix.get_scores()
        matrix.reset()
        scores_list.append(scores)
        loss = loss_fn(
            predictions,
            torch.squeeze(labels, dim=1).to(dtype=torch.long)
        )
        loss_value = loss.item()
        # Same NaN guard as in fit().
        if np.isnan(loss_value):
            loss_value = max(loss_list) if len(loss_list) != 0 else 1.0
        loss_list.append(loss_value)
        progress_bar.set_description(
            f"val ---> Epoch {epoch + 1} / {epochs}, batch_loss: {loss_value:.3f}, "
            f"batch_iou: {scores['avg_iou']:.3f}, batch_accuracy: {scores['accuracy']:.3f}"
        )
    progress_bar.close()
    return sum(loss_list) / len(loss_list), utils.avg_confusion_matrix_scores_list(scores_list)
"""
模型训练
net: 网络模型
optimizer: 优化器,
lr_scheduler: 学习率调度器,
weight: 每一类的权重
root_path: 存储训练数据和验证数据的根目录
train_dir_names: 存储训练数据的目录,元组形式(images_path, labels_path)
val_dir_names: 存储验证数据的目录, 元组形式(images_path, labels_path)
classes_num: 类别数量
yaml_path: 配置文件路径
"""
def train(
net,
optimizer,
lr_scheduler,
train_config=Path("config") / "train.yaml",
model_config=Path("config") / "model.yaml"
):
    with model_config.open("r", encoding="utf-8") as mcf:
        model_config = yaml.load(mcf, yaml.FullLoader)
    classes_num = len(model_config["classes"])
    with train_config.open("r", encoding="utf-8") as tcf:
        train_config = yaml.load(tcf, Loader=yaml.Loader)
    device = train_config["device"]
    epochs = train_config["epochs"]
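    # Sketch of the expected train.yaml layout, inferred from the keys this
    # function reads; the values below are illustrative, not from the repo:
    #   device: cuda
    #   epochs: 100
    #   batch_size: 8
    #   workers: 4
    #   image_height: 512
    #   image_width: 512
    #   eval_every_n_epoch: 5
    #   weight: [1.0]            # per-class loss weights; a single entry means uniform
    #   root: [data]             # path components, joined with os.path.sep
    #   train_dir_name: [train]
    #   val_dir_name: [val]
    #   images_dir_name: images
    #   labels_dir_name: labels
    #   save_path: [checkpoints, segnext]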
    train_images_dataset = data_utils.Pic2PicDataset(
        root=os.path.sep.join(train_config["root"]),
        x_dir_name=Path(os.path.sep.join(train_config["train_dir_name"])) / train_config["images_dir_name"],
        y_dir_name=Path(os.path.sep.join(train_config["train_dir_name"])) / train_config["labels_dir_name"]
    )
    train_data_loader = DataLoader(
        dataset=train_images_dataset,
        batch_size=train_config["batch_size"],
        shuffle=True,
        num_workers=train_config["workers"]
    )
    val_images_dataset = data_utils.Pic2PicDataset(
        root=os.path.sep.join(train_config["root"]),
        x_dir_name=Path(os.path.sep.join(train_config["val_dir_name"])) / train_config["images_dir_name"],
        y_dir_name=Path(os.path.sep.join(train_config["val_dir_name"])) / train_config["labels_dir_name"]
    )
    val_data_loader = DataLoader(
        dataset=val_images_dataset,
        batch_size=train_config["batch_size"],
        shuffle=False,
        num_workers=train_config["workers"]
    )
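    # Assumed on-disk layout, implied by the dataset construction above
    # (directory names come from train.yaml):
    #   <root>/<train_dir_name>/<images_dir_name>/  training images
    #   <root>/<train_dir_name>/<labels_dir_name>/  training label masks
    #   <root>/<val_dir_name>/<images_dir_name>/    validation images
    #   <root>/<val_dir_name>/<labels_dir_name>/    validation label masks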
    image_height, image_width = train_config["image_height"], train_config["image_width"]
    # A single-entry weight list in the config means "no class weighting":
    # fall back to uniform weights across all classes.
    weight = torch.tensor(train_config["weight"]) if len(train_config["weight"]) != 1 else torch.ones(classes_num)
    loss_fn = losses.FocalLoss(
        weight=weight.to(device)
    )
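    # For reference, focal loss (Lin et al., 2017) down-weights easy examples:
    #   FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    # The exact signature of losses.FocalLoss is defined in this repo's losses
    # module; the per-class `weight` tensor presumably plays the role of alpha_t.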
    max_train_iou, max_val_iou = -np.inf, -np.inf
    best_train_model, best_val_model = None, None
    for epoch in range(0, epochs):
        # Training
        net.train()
        train_avg_loss, train_avg_scores = fit(
            epochs=epochs,
            epoch=epoch,
            net=net,
            train_data_loader=train_data_loader,
            image_size=(image_height, image_width),
            classes_num=classes_num,
            loss_fn=loss_fn,
            lr_scheduler=lr_scheduler,
            optimizer=optimizer,
            device=device
        )
        print()
        print(utils.confusion_matrix_scores2table(train_avg_scores))
        print(f"train_avg_loss: {train_avg_loss:.3f}")
        if max_train_iou < train_avg_scores["avg_iou"]:
            max_train_iou = train_avg_scores["avg_iou"]
            # Deep-copy the state dicts: state_dict() returns references, so
            # without a copy this "best" snapshot would track the latest weights.
            best_train_model = {
                "state_dict": copy.deepcopy(net.state_dict()),
                "optimizer": copy.deepcopy(optimizer.state_dict()),
                "avg_iou": max_train_iou
            }
        # Validation
        if (epoch + 1) % train_config["eval_every_n_epoch"] == 0:
            net.eval()
            val_avg_loss, val_avg_scores = val(
                epochs=epochs,
                epoch=epoch,
                net=net,
                val_data_loader=val_data_loader,
                image_size=(image_height, image_width),
                classes_num=classes_num,
                loss_fn=loss_fn,
                device=device
            )
            print()
            print(utils.confusion_matrix_scores2table(val_avg_scores))
            print(f"val_avg_loss: {val_avg_loss:.3f}")
            if max_val_iou < val_avg_scores["avg_iou"]:
                max_val_iou = val_avg_scores["avg_iou"]
                best_val_model = {
                    "state_dict": copy.deepcopy(net.state_dict()),
                    "optimizer": copy.deepcopy(optimizer.state_dict()),
                    "avg_iou": max_val_iou
                }
            # Save the current weights after every evaluation; the filename is
            # tagged with the best validation IoU seen so far.
            m = {
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "avg_iou": val_avg_scores["avg_iou"]
            }
            torch.save(
                obj=m,
                f=f"{os.path.sep.join(train_config['save_path'])}_Iou{100 * best_val_model['avg_iou']:.3f}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}.pth"
            )
    torch.save(
        obj=best_train_model,
        f=f"{os.path.sep.join(train_config['save_path'])}_train_Iou{100 * best_train_model['avg_iou']:.3f}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}.pth"
    )
    # Only save a best-validation checkpoint if at least one evaluation ran.
    if best_val_model is not None:
        torch.save(
            obj=best_val_model,
            f=f"{os.path.sep.join(train_config['save_path'])}_val_Iou{100 * best_val_model['avg_iou']:.3f}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}.pth"
        )
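

# A minimal sketch for restoring one of the checkpoints saved above, assuming
# the same model/optimizer construction as in the __main__ block below:
#   checkpoint = torch.load("path/to/checkpoint.pth", map_location="cpu")
#   net.load_state_dict(checkpoint["state_dict"])
#   optimizer.load_state_dict(checkpoint["optimizer"])
#   print(checkpoint["avg_iou"])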


def count_parameters(model):
    """Return the number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    net = model_utils.get_model(True)
    optimizer = model_utils.get_optimizer(net)
    lr_scheduler = model_utils.get_lr_scheduler(optimizer=optimizer)
    model_utils.init_model(
        train=True,
        net=net,
        optimizer=optimizer
    )
    # Compute and print the number of network parameters
    # num_params = count_parameters(net)
    # print(f"number of network parameters: {num_params}")
    train(
        net=net,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler
    )