SegNeXt/predict.py
2023-04-07 22:29:25 +08:00

115 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import yaml
from PIL import Image
import data_utils
import torch
from pathlib import Path
import model_utils
import utils
from matplotlib import pyplot as plt
def predict(
        net,
        image: "Image.Image",
        cls_name,
        predict_config=Path("config") / "predict.yaml",
        model_config=Path("config") / "model.yaml",
):
    """Predict a boolean mask of one class for an image.

    Args:
        net: Segmentation network, called as ``net(image_tensor)``.
        image: Input PIL image, converted with ``data_utils.pil2tensor``.
        cls_name: Name of the class to extract. Its label is looked up in
            the ``classes`` list of the model config.
        predict_config: Path (``Path`` or ``str``) to the prediction YAML
            config. Its ``device`` key selects the device.
        model_config: Path (``Path`` or ``str``) to the model YAML config.
            Its ``classes`` key lists all class names.

    Returns:
        A boolean tensor that is True where the predicted label equals the
        label of ``cls_name``. Size-1 dimensions are squeezed away, so for a
        single image the result is ``[image_height, image_width]``.
    """
    # Wrap in Path() so plain string paths are accepted as well.
    with Path(model_config).open("r", encoding="utf-8") as mcf:
        classes = yaml.load(mcf, Loader=yaml.FullLoader)["classes"]
    with Path(predict_config).open("r", encoding="utf-8") as pcf:
        device = yaml.load(pcf, Loader=yaml.FullLoader)["device"]

    tensor = data_utils.pil2tensor(image, device)
    if len(tensor.shape) == 3:
        # A single image: add a leading batch dimension.
        tensor = torch.unsqueeze(tensor, dim=0)
    batch_size, _, image_height, image_width = tensor.shape

    # Pure inference: skip building an autograd graph to save memory/time.
    with torch.no_grad():
        # The network output has its last two axes swapped relative to
        # [batch, classes, H*W]. Transpose, then restore the spatial layout.
        logits = torch.transpose(net(tensor), -2, -1).reshape(
            batch_size, len(classes), image_height, image_width
        )
        prediction = data_utils.inv_one_hot_of_outputs(logits, device)
        label = utils.get_label_of_cls(classes, cls_name)[0]
        mask = torch.squeeze(prediction == label)
    return mask
def blend(
        image: "Image.Image",
        mask,
        classes,
        cls_name,
        colors,
        alpha=0.5,
):
    """Overlay a single-class mask, painted in that class's color, on an image.

    Args:
        image: Original PIL image. It is converted to RGB when needed,
            because ``Image.blend`` requires both images to share a mode and
            the overlay is RGB.
        mask: Boolean mask of shape ``[height, width]``, e.g. the result of
            ``predict``. A torch tensor on any device or a numpy array is
            accepted. It must match the image size.
        classes: All class names.
        cls_name: Class whose color is used to paint the masked pixels.
        colors: Colors of all classes, resolved via
            ``utils.get_color_of_cls``.
        alpha: Overlay weight passed to ``Image.blend``. 0.0 returns the
            original image and 1.0 returns only the overlay. The default is
            0.5.

    Returns:
        The blended PIL image.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().to(device="cpu").numpy()
    # Force a boolean array so indexing selects pixels instead of doing
    # integer fancy-indexing.
    mask = np.asarray(mask, dtype=bool)

    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    overlay[mask] = utils.get_color_of_cls(classes, colors, cls_name)
    overlay = Image.fromarray(overlay)

    if image.mode != "RGB":
        image = image.convert("RGB")
    return Image.blend(image, overlay, alpha)
def show_image(image):
    """Display ``image`` with matplotlib.

    Args:
        image: Image to draw, e.g. a PIL image or an array of shape
            ``[height, width, channels=3]``. It is handed directly to
            ``plt.imshow``.
    """
    # Render onto the current figure first, then open the figure window.
    plt.imshow(image)
    plt.show()
if __name__ == "__main__":
    # Demo: segment the "leaf" class on one test image and show the overlay.
    # Use pathlib for all paths, consistent with predict()'s defaults.
    with (Path("config") / "model.yaml").open("r", encoding="utf-8") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    classes = model_config["classes"]
    colors = utils.get_colors(len(classes))
    image_path = Path("dataset") / "test" / "biomass_image_train_0233_8.jpg"
    cls_name = "leaf"
    net = model_utils.get_model(False)
    model_utils.init_model(False, net)
    # Image.open is lazy and keeps the file open. Decode inside a context
    # manager so the handle is closed, and convert to RGB so the image
    # matches the RGB overlay built in blend().
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    mask = predict(net, image, cls_name)
    show_image(blend(image, mask, classes, cls_name, colors))