SegNeXt/predict.py

115 lines
3.1 KiB
Python
Raw Normal View History

2023-04-07 22:29:25 +08:00
import os
import numpy as np
import yaml
from PIL import Image
import data_utils
import torch
from pathlib import Path
import model_utils
import utils
from matplotlib import pyplot as plt
def predict(
    net,
    image: Image.Image,
    cls_name,
    predict_config=Path("config") / "predict.yaml",
    model_config=Path("config") / "model.yaml"
):
    """
    Run the network on one image and return the mask of a single class.

    @:param net: network model; called as net(image_tensor)
    @:param image: input image (PIL)
    @:param cls_name: name of the class whose mask is wanted
    @:param predict_config: path (Path or str) of the prediction config file;
        its "device" entry selects where the image tensor is placed
    @:param model_config: path (Path or str) of the model config file;
        its "classes" entry lists all classes
    @:return mask: boolean tensor, True where the predicted label equals the
        label of cls_name; documented by the original author as
        [image_height, image_width] for a single image

    NOTE(review): this function does not switch the network to eval mode;
    presumably model_utils.init_model takes care of that - confirm.
    """
    # NOTE(review): FullLoader can build more than plain scalars/lists/dicts.
    # The config files are assumed to be trusted local files; prefer
    # yaml.safe_load if they can ever come from an untrusted source.
    with Path(model_config).open("r", encoding="utf-8") as mcf:
        classes = yaml.load(mcf, Loader=yaml.FullLoader)["classes"]
    with Path(predict_config).open("r", encoding="utf-8") as pcf:
        device = yaml.load(pcf, Loader=yaml.FullLoader)["device"]

    image = data_utils.pil2tensor(image, device)
    # A single image without a batch axis gets one prepended.
    if len(image.shape) == 3:
        image = torch.unsqueeze(image, dim=0)
    batch_size, _, image_height, image_width = image.shape

    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        outputs = net(image)
        # Swap the last two axes and fold the result back into an image grid
        # (assumes the net emits [batch, height * width, classes] - TODO confirm).
        scores = torch.transpose(outputs, -2, -1).reshape(
            batch_size, len(classes), image_height, image_width
        )
        prediction = data_utils.inv_one_hot_of_outputs(scores, device)

    # Keep only the pixels predicted as the requested class; squeeze drops
    # every size-1 axis (e.g. the batch axis of a single image).
    target_label = utils.get_label_of_cls(classes, cls_name)[0]
    return torch.squeeze(prediction == target_label)
def blend(
    image: Image.Image,
    mask,
    classes,
    cls_name,
    colors,
    alpha=0.5
):
    """
    Blend the predicted mask of one class onto the original image.

    @:param image: original image (PIL); converted to RGB when it is in
        another mode so that it can be blended with the RGB overlay
    @:param mask: boolean mask tensor returned by predict for cls_name
    @:param classes: all classes
    @:param cls_name: class the mask belongs to
    @:param colors: list of colors, one per class
    @:param alpha: weight of the colored overlay passed to Image.blend;
        0.0 yields the original image, 1.0 the overlay alone (default 0.5)
    @:return the blended image

    Image.blend raises ValueError when the overlay built from the mask and
    the image differ in size.
    """
    mask = mask.to(device="cpu").numpy()

    # Black canvas with the class color painted wherever the mask is True.
    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    overlay[mask] = utils.get_color_of_cls(classes, colors, cls_name)
    overlay = Image.fromarray(overlay)

    # The overlay is always RGB; Image.blend rejects images whose modes
    # differ, so normalize inputs such as RGBA, L or P first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return Image.blend(image, overlay, alpha)
def show_image(image):
    """
    Display an image with matplotlib.

    @:param image: image to display; per the original author's note its size
        should be [height, width, channels=3]
    """
    plt.imshow(image)
    plt.show()
if __name__ == "__main__":
    # Demo: segment one test image and display the "leaf" mask on top of it.

    # The class list comes from the model config; one display color per class.
    model_yaml = Path("config", "model.yaml")
    with model_yaml.open("r", encoding="utf-8") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    classes = model_config["classes"]
    colors = utils.get_colors(len(classes))

    image_path = os.path.join(
        "dataset", "test", "biomass_image_train_0233_8.jpg"
    )
    cls_name = "leaf"

    # NOTE(review): the meaning of the False flag is defined in model_utils;
    # presumably it selects the non-training setup - confirm.
    net = model_utils.get_model(False)
    model_utils.init_model(False, net)

    image = Image.open(image_path)
    mask = predict(net, image, cls_name)
    blended = blend(image, mask, classes, cls_name, colors)
    show_image(blended)