115 lines
3.1 KiB
Python
115 lines
3.1 KiB
Python
![]() |
import os
|
|||
|
|
|||
|
import numpy as np
|
|||
|
import yaml
|
|||
|
from PIL import Image
|
|||
|
import data_utils
|
|||
|
import torch
|
|||
|
from pathlib import Path
|
|||
|
import model_utils
|
|||
|
import utils
|
|||
|
from matplotlib import pyplot as plt
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def predict(
        net,
        image: Image.Image,
        cls_name,
        predict_config=Path("config") / "predict.yaml",
        model_config=Path("config") / "model.yaml"
):
    """
    Predict the mask of a single class for an image.

    @:param net: network model
    @:param image: input image (PIL)
    @:param cls_name: name of the class whose mask is returned
    @:param predict_config: path to the prediction config file (str or Path);
        the file must provide a "device" entry
    @:param model_config: path to the model config file (str or Path);
        the file must provide a "classes" entry

    @:return mask: boolean tensor; for a single image this is
        [image_height, image_width] (torch.squeeze drops every size-1 dimension)
    """
    # Accept plain strings as well as Path objects for the config locations.
    with Path(model_config).open("r", encoding="utf-8") as mcf:
        classes = yaml.load(mcf, Loader=yaml.FullLoader)["classes"]

    with Path(predict_config).open("r", encoding="utf-8") as pcf:
        device = yaml.load(pcf, Loader=yaml.FullLoader)["device"]

    image = data_utils.pil2tensor(image, device)
    # A single un-batched image gets a leading batch dimension of size 1.
    if len(image.shape) == 3:
        image = torch.unsqueeze(image, dim=0)
    batch_size, _, image_height, image_width = image.shape

    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        # Swap the last two axes of the network output, then reshape it to
        # [batch, num_classes, height, width] before decoding per-pixel labels.
        # NOTE(review): assumes net's output layout matches this reshape - confirm.
        logits = torch.transpose(net(image), -2, -1).reshape(
            batch_size, len(classes), image_height, image_width
        )
        prediction = data_utils.inv_one_hot_of_outputs(logits, device)

    # Pixels whose predicted label equals the label of cls_name.
    mask = torch.squeeze(
        prediction == utils.get_label_of_cls(classes, cls_name)[0]
    )
    return mask
|
|||
|
|
|||
|
|
|||
|
def blend(
        image: Image.Image,
        mask,
        classes,
        cls_name,
        colors
):
    """
    Overlay the predicted mask of one class on the original image.

    @:param image: original image (PIL); converted to the overlay's mode (RGB)
        if it differs, so RGBA / grayscale inputs also work
    @:param mask: per-pixel mask of one class, as returned by predict
        (torch tensor or array-like); truthy entries receive the class colour
    @:param classes: all class names
    @:param cls_name: the class whose colour is painted
    @:param colors: list of colours for all classes
    @:return the blended image: a 50/50 mix of the original and the colour overlay
    """
    # Accept torch tensors (detached and moved to CPU) as well as numpy arrays.
    if hasattr(mask, "detach"):
        mask = mask.detach().to(device="cpu").numpy()
    # Boolean indexing is required below; an integer 0/1 mask would otherwise be
    # interpreted as fancy indices and select the wrong pixels.
    mask = np.asarray(mask, dtype=bool)

    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    overlay[mask] = utils.get_color_of_cls(classes, colors, cls_name)
    overlay = Image.fromarray(overlay)

    # Image.blend requires both images to share size and mode; the overlay is
    # RGB, so convert the original instead of raising ValueError.
    if image.mode != overlay.mode:
        image = image.convert(overlay.mode)
    return Image.blend(image, overlay, 0.5)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def show_image(image):
    """
    Draw an image on the current matplotlib figure and display it.

    @:param image: the image to display, expected layout [height, width, channels=3]
    """
    # imshow only draws onto the active axes; show() is what renders the figure.
    plt.imshow(image)
    plt.show()
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Demo: predict the "leaf" mask for one test image and show it blended
    # over the original.
    with (Path("config") / "model.yaml").open("r", encoding="utf-8") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    classes = model_config["classes"]

    colors = utils.get_colors(len(classes))

    image_path = Path("dataset") / "test" / "biomass_image_train_0233_8.jpg"

    cls_name = "leaf"
    net = model_utils.get_model(False)
    model_utils.init_model(False, net)

    # Close the underlying file once the pixels are loaded; convert() returns a
    # fully loaded RGB copy, which is also the mode blend() overlays with.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    mask = predict(net, image, cls_name)
    show_image(blend(image, mask, classes, cls_name, colors))
|