# SegNeXt/utils.py
import colorsys
import copy
import json
import math
import os
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
from tabulate import tabulate
from torchvision.transforms import transforms, InterpolationMode
"""
生成num种颜色
返回值: color list返回的color list的第一个数值永远是(0, 0, 0)
"""
def get_colors(num: int):
assert num >= 1
if num <= 21:
colors = [
(0, 0, 0),
(128, 0, 0),
(0, 128, 0),
(128, 128, 0),
(0, 0, 128),
(128, 0, 128),
(0, 128, 128),
(128, 128, 128),
(64, 0, 0),
(192, 0, 0),
(64, 128, 0),
(192, 128, 0),
(64, 0, 128),
(192, 0, 128),
(64, 128, 128),
(192, 128, 128),
(0, 64, 0),
(128, 64, 0),
(0, 192, 0),
(128, 192, 0),
(0, 64, 128),
(128, 64, 12)
]
else:
hsv_tuples = [(x / num, 1., 1.) for x in range(0, num - 1)]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
if (0, 0, 0) in colors:
colors.remove((0, 0, 0))
colors = [(0, 0, 0), *colors]
return colors
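
# Illustrative usage sketch (not part of the original module); the call below is
# an assumption for demonstration only:
#
#   colors = get_colors(3)            # [(0, 0, 0), (128, 0, 0), (0, 128, 0), ...]
#   assert colors[0] == (0, 0, 0)     # the background color is always first
#   palette = colors[:3]              # one (r, g, b) triple per class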
"""
获取某种颜色对应的标签
返回值:标签值
"""
def get_label_of_color(colors, color):
low_label = colors.index(color)
return low_label, 255 - low_label
"""
获取某个标签值对应的颜色
返回值:元组(r, g, b)
"""
def get_color_of_label(colors, label):
low_label = label if label < 255 - label else 255 - label
return colors[low_label]
"""
获取某种类别对应的标签
返回值:标签值
"""
def get_label_of_cls(classes, cls):
low_label = classes.index(cls)
return low_label, 255 - low_label
"""
获取某个标签值对应的类别
返回值:类别
"""
def get_cls_of_label(classes, label):
low_label = label if label < 255 - label else 255 - label
return classes[low_label]
"""
获取某种颜色对应的类别
返回值:类别
color: (r, g, b)
"""
def get_cls_of_color(classes, colors, color):
idx = colors.index(color)
return get_cls_of_label(classes, idx)
"""
获取某种类别对应的颜色
返回值:颜色,(r, g, b)
"""
def get_color_of_cls(classes, colors, cls):
idx = classes.index(cls)
return get_color_of_label(colors, idx)
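
# Minimal sketch of the label/color/class round trip implemented above; the class
# list and assertions are illustrative assumptions, not taken from the repo:
#
#   classes = ["background", "leaf"]
#   colors = get_colors(len(classes))
#   low, high = get_label_of_cls(classes, "leaf")      # (1, 254)
#   assert get_cls_of_label(classes, high) == "leaf"   # 254 folds back to low label 1
#   assert get_color_of_cls(classes, colors, "leaf") == colors[1]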


def draw_mask(draw, points, shape_type, label, out_line_value, line_width=10, point_width=5):
    """Rasterize a single labelme shape into the mask: interior = label, outline = out_line_value."""
    points = [tuple(point) for point in points]
    if shape_type == 'circle':
        assert len(points) == 2, 'Shape of shape_type=circle must have 2 points'
        (cx, cy), (px, py) = points
        d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
        draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=out_line_value, fill=label)
    elif shape_type == 'rectangle':
        assert len(points) == 2, 'Shape of shape_type=rectangle must have 2 points'
        draw.rectangle(points, outline=out_line_value, fill=label)
    elif shape_type == 'line':
        assert len(points) == 2, 'Shape of shape_type=line must have 2 points'
        draw.line(xy=points, fill=out_line_value, width=line_width)
    elif shape_type == 'linestrip':
        draw.line(xy=points, fill=out_line_value, width=line_width)
    elif shape_type == 'point':
        assert len(points) == 1, 'Shape of shape_type=point must have exactly 1 point'
        cx, cy = points[0]
        r = point_width
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=out_line_value, fill=label)
    else:
        assert len(points) > 2, 'Polygon must have more than 2 points'
        draw.polygon(xy=points, outline=out_line_value, fill=label)
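
# Hedged usage sketch for draw_mask: rasterize one rectangle with class label 1 into
# a uint8 mask. The 64x64 canvas and the points are made-up values for illustration:
#
#   canvas = Image.fromarray(np.zeros((64, 64), dtype=np.uint8))
#   draw_mask(ImageDraw.Draw(canvas), points=[[8, 8], [40, 40]],
#             shape_type="rectangle", label=1, out_line_value=254)
#   assert np.array(canvas)[20, 20] == 1   # interior pixels carry the class label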
"""
负责将labelme的标记转换成(mask)图像
classes: 类别列表
"""
def labelme_json2mask(classes, json_path: str, mask_saved_path: str):
assert classes is not None and classes[0] == "background"
json_path = Path(json_path)
if json_path.exists() and json_path.is_file():
with json_path.open(mode="r") as f:
json_data = json.load(f)
image_height = json_data["imageHeight"]
image_width = json_data["imageWidth"]
image_path = json_data["imagePath"]
shapes = json_data["shapes"]
cls_info_list = []
for shape in shapes:
cls_name_in_json = shape["label"]
assert cls_name_in_json in classes
points = shape["points"]
shape_type = shape["shape_type"]
label_of_cls = classes.index(cls_name_in_json)
cls_info_list.append(
{
"cls_name": cls_name_in_json,
"label": label_of_cls,
"points": points,
"shape_type": shape_type
}
)
mask = np.zeros(shape=(image_height, image_width), dtype=np.uint8)
mask = Image.fromarray(mask)
draw = ImageDraw.Draw(mask)
for cls_info in cls_info_list:
points = cls_info["points"]
shape_type = cls_info["shape_type"]
label = cls_info["label"]
draw_mask(draw, points, shape_type, label, 255 - label)
mask = np.array(mask)
mask = Image.fromarray(mask)
mask.save(str(Path(mask_saved_path) / (str(image_path).split(".")[0] + ".png")))
os.remove(json_path)
"""
将root_path下labelme生成的json文件全部进行处理
1. 有原图匹配的json文件会转换成mask存储到mask_saved_path路径下
2. 没有原图但是有json文件的直接删除该json文件
3. 有原图但是没有json文件的会在mask_saved_path下生成一个纯黑背景图片
root_path: 存储着原图和json文件原图后缀名尽量为.jpg
"""
def convert_labelme_jsons2masks(classes, root_path: str, mask_saved_path: str, original_image_suffix=".jpg"):
assert 0 < len(classes) <= 128
original_images = set(
map(
lambda name: str(name).split(".")[0],
Path(root_path).glob(pattern=f"*{original_image_suffix}")
)
)
json_files = Path(root_path).glob(pattern="*.json")
for json_file in json_files:
filename = str(json_file).split(".")[0]
if filename in original_images:
labelme_json2mask(classes, str(json_file), mask_saved_path)
original_images.remove(filename)
else:
os.remove(json_file)
if len(original_images) != 0:
for image_filename in original_images:
image_path = image_filename + f"{original_image_suffix}"
image = Image.open(image_path)
height, width = image.height, image.width
image.close()
mask = np.zeros((height, width), dtype=np.uint8)
mask = Image.fromarray(mask)
mask.save(str(Path(mask_saved_path) / (os.path.basename(image_filename) + ".png")))
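
# Example invocation (a sketch; the directory layout is an assumption based on the
# commented calls in __main__ below): every image/JSON pair under the root is
# rasterized into a PNG label map, unmatched JSON files are deleted, and images
# without a JSON file get an all-background mask.
#
#   convert_labelme_jsons2masks(
#       classes=["background", "leaf"],
#       root_path="dataset/train/images",
#       mask_saved_path="dataset/train/labels"
#   )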
"""
将混淆矩阵得到的尺度(scores)组合成表格形式输出到控制台
scores: 混淆矩阵的尺度(scores)
"""
def confusion_matrix_scores2table(scores):
assert scores is not None and isinstance(scores, dict)
classes = [tp[0] for tp in scores["classes_precision"]]
cls_precision_list = [tp[-1] for tp in scores["classes_precision"]]
cls_recall_list = [tp[-1] for tp in scores["classes_recall"]]
cls_iou_list = [tp[-1] for tp in scores["classes_iou"]]
table1 = tabulate(
tabular_data=np.concatenate(
(
np.asarray(classes).reshape(-1, 1),
np.asarray(cls_precision_list).reshape(-1, 1),
np.asarray(cls_recall_list).reshape(-1, 1),
np.asarray(cls_iou_list).reshape(-1, 1)
), 1
),
headers=["classes", "precision", "recall", "iou"],
tablefmt="grid"
)
avg_precision = scores["avg_precision"]
avg_recall = scores["avg_recall"]
avg_iou = scores["avg_iou"]
accuracy = scores["accuracy"]
table2 = tabulate(
tabular_data=[(avg_precision, avg_recall, avg_iou, accuracy)],
headers=["avg_precision", "avg_recall", "avg_iou", "accuracy"],
tablefmt="grid"
)
table = tabulate(
tabular_data=np.concatenate(
(
np.asarray(["single", "overall"]).reshape(-1, 1),
np.asarray([table1, table2]).reshape(-1, 1)
), 1
),
headers=["table type", "table"],
tablefmt="grid"
)
return table
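
# Sketch of the scores dict layout this function expects, inferred from the code
# above; the class names and numbers are invented purely to show the structure:
#
#   scores = {
#       "classes_precision": [("background", 0.98), ("leaf", 0.91)],
#       "classes_recall":    [("background", 0.97), ("leaf", 0.89)],
#       "classes_iou":       [("background", 0.95), ("leaf", 0.83)],
#       "avg_precision": 0.945, "avg_recall": 0.93, "avg_iou": 0.89, "accuracy": 0.96
#   }
#   print(confusion_matrix_scores2table(scores))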
"""
相加混淆矩阵得到的两个scores
返回值:
相加后的混淆矩阵
"""
def sum_2_confusion_matrix_scores(scores_left: dict, scores_right: dict):
scores_left["classes_precision"] = [
(tp[0][0], tp[0][-1] + tp[-1][-1]) for tp in zip(scores_left["classes_precision"], scores_right["classes_precision"])
]
scores_left["classes_recall"] = [
(tp[0][0], tp[0][-1] + tp[-1][-1]) for tp in zip(scores_left["classes_recall"], scores_right["classes_recall"])
]
scores_left["classes_iou"] = [
(tp[0][0], tp[0][-1] + tp[-1][-1]) for tp in zip(scores_left["classes_iou"], scores_right["classes_iou"])
]
scores_left["avg_precision"] = scores_left["avg_precision"] + scores_right["avg_precision"]
scores_left["avg_recall"] = scores_left["avg_recall"] + scores_right["avg_recall"]
scores_left["avg_iou"] = scores_left["avg_iou"] + scores_right["avg_iou"]
scores_left["accuracy"] = scores_left["accuracy"] + scores_right["accuracy"]
return scores_left
"""
将混淆矩阵列表内的scores进行相加
@:param scores_list: 得分列表
@:return 相加后的得分
"""
def sum_confusion_matrix_scores_list(scores_list):
if len(scores_list) == 1:
return scores_list[0]
result = scores_list[0]
for i in range(1, len(scores_list)):
result = sum_2_confusion_matrix_scores(result, scores_list[i])
return result
"""
对混淆矩阵得出的scores_list相加后求平均
返回值:
相加后求平均的scores
"""
def avg_confusion_matrix_scores_list(scores_list):
assert scores_list is not None and len(scores_list) >= 1
result = sum_confusion_matrix_scores_list(scores_list)
result["classes_precision"] = [
(tp[0], tp[-1] / len(scores_list)) for tp in result["classes_precision"]
]
result["classes_recall"] = [
(tp[0], tp[-1] / len(scores_list)) for tp in result["classes_recall"]
]
result["classes_iou"] = [
(tp[0], tp[-1] / len(scores_list)) for tp in result["classes_iou"]
]
result["avg_precision"] = result["avg_precision"] / len(scores_list)
result["avg_recall"] = result["avg_recall"] / len(scores_list)
result["avg_iou"] = result["avg_iou"] / len(scores_list)
result["accuracy"] = result["accuracy"] / len(scores_list)
return result
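
# Hedged example: averaging per-batch scores dicts (same structure as sketched
# above). The batch dicts are hypothetical; note that sum_2_confusion_matrix_scores
# mutates its left argument, so the first list entry is updated in place:
#
#   batch_scores = [scores_batch_0, scores_batch_1]   # hypothetical per-batch dicts
#   epoch_scores = avg_confusion_matrix_scores_list(batch_scores)
#   print(confusion_matrix_scores2table(epoch_scores))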
"""
对原始作为x的输入图像进行增强预处理产生相同大小的图片(旋转、翻转、亮度调整)
ts是pytorch工具包经过该工具包处理后图像如果和原本的不同
就会保存在磁盘上,以达到增强数据的目的,请先执行该函数之后,再对原始数
据图像进行人工标注。
root_path目录下的数据只有图片且图片后缀名一致
root_path: 作为x的原始输入图像所在目录
ts: 预处理策略
"""
def augment_raw_images2(
root_path,
ts=transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(degrees=30),
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)
]
)
):
image_paths = Path(root_path).glob(pattern="*")
for image_path in image_paths:
counter = 0
image_filename, image_suffix = os.path.splitext(image_path)
image = Image.open(image_path)
image_np = np.array(image)
for transform in ts.transforms:
new_image = transform(Image.fromarray(image_np))
new_image_np = np.array(new_image)
if not np.array_equal(image_np, new_image_np):
new_image_copy = Image.fromarray(new_image_np)
new_image_copy.save(str(Path(f"{image_filename}_{counter}{image_suffix}")))
new_image_copy.close()
counter += 1
new_image.close()
image.close()
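
# Usage sketch (the path mirrors the commented calls in __main__): run the
# flip/rotation/color-jitter augmentation in place, then annotate the generated
# images afterwards as the docstring recommends.
#
#   augment_raw_images2(root_path="dataset/train/images")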
"""
对原始作为x的输入图像进行增强预处理产生image_cropped_shape大小的图片
现将图像resize为image_resized_shape大小然后进行1次裁剪和1次随机裁剪裁剪的图像保留下来原始图像放入to_path中
ts是pytorch工具包经过该工具包处理后图像如果和原本的不同
就会保存在磁盘上,以达到增强数据的目的,请先执行该函数之后,再对原始数
据图像进行人工标注。
from_path目录下的数据只有图片且图片后缀名一致
from_path: 作为x的原始输入图像所在目录
to_path: 处理后的原始图像放入哪里如果为None就删除原始图像
image_resized_shape: 图像resize之后的大小, image_cropped_shape每个维度必须小于image_resized_shape
image_cropped_shape: 图像裁剪后的大小image_cropped_shape每个维度必须小于image_resized_shape
ts: 预处理策略
"""
def augment_raw_images(
from_path,
to_path="to/path",
image_resized_shape=(256, 256),
image_cropped_shape=(224, 224),
ts=None
):
if ts is None:
ts = transforms.Compose(
[
transforms.Resize(image_resized_shape, interpolation=InterpolationMode.BILINEAR),
transforms.RandomCrop(image_cropped_shape),
transforms.RandomResizedCrop(image_cropped_shape)
]
)
image_paths = Path(from_path).glob("*")
for image_path in image_paths:
counter = 0
image_filename, image_suffix = os.path.splitext(image_path)
with Image.open(image_path) as image:
image = ts.transforms[0](image)
image_copy_np = copy.deepcopy(np.array(image))
for transform in ts.transforms[0:]:
image = transform(image)
image_np = np.array(image)
if not np.array_equal(image_np, image_copy_np):
image.save(str(Path(f"{image_filename}_{counter}{image_suffix}")))
counter = counter + 1
image.close()
image = Image.fromarray(image_copy_np)
if to_path:
Path(image_path).rename(Path(to_path) / f"{os.path.basename(image_path)}")
else:
Path(image_path).unlink()
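
# Usage sketch (all arguments are assumptions): resize to 256x256, keep the 224x224
# crops next to the originals and move the untouched originals into a backup folder.
#
#   augment_raw_images(
#       from_path="dataset/train/images",
#       to_path="dataset/train/raw_backup",   # hypothetical destination directory
#       image_resized_shape=(256, 256),
#       image_cropped_shape=(224, 224)
#   )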
"""
对验证数据集中的图片进行大小的统一,以便其拥有统一的大小,可以进行批次训练
from_path: 验证数据集所在的目录
to_path: 原始数据应该转移到哪里
resized_shape: (height, width), resize后的大小
"""
def resize_val_images(from_path, to_path, resized_shape):
image_paths = Path(from_path).glob(pattern="*")
for image_path in image_paths:
original_image = Image.open(image_path)
original_image_np = np.array(original_image)
resized_image = Image.fromarray(original_image_np).resize(resized_shape)
original_image.close()
if not to_path:
Path(image_path).unlink(missing_ok=True)
else:
Path(image_path).rename(Path(to_path) / os.path.basename(image_path))
resized_image.save(image_path)
resized_image.close()
"""
将一张图片按照尺寸裁剪为多张图片
@:param image: 图片
@:param crop_size: 裁剪尺寸为tuple(image_height, image_width)
@:return 裁剪之后的图片列表
"""
def crop_image2images(image: Image, crop_size):
image_np = np.array(image)
image_height, image_width = image_np.shape[:-1]
left_image_height, left_image_width = image_np.shape[:-1]
crop_height, crop_width = crop_size
left_upper = (0, 0)
right_lower = (crop_width, crop_height)
image_list = []
while left_image_width / crop_width >= 1 or left_image_height / crop_height >= 1:
if left_image_width / crop_width >= 1 and left_image_height / crop_height >= 1:
new_image = image.crop((*left_upper, *right_lower))
left_image_width -= crop_width
left_upper = (left_upper[0] + crop_width, left_upper[-1])
right_lower = (right_lower[0] + crop_width, right_lower[-1])
image_list.append(new_image)
elif left_image_height / crop_height >= 1:
left_image_width = image_width
left_image_height -= crop_height
left_upper = (0, image_height - left_image_height)
right_lower = (crop_width, image_height - left_image_height + crop_height)
else:
break
return image_list
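
# Illustrative sketch: tiling a hypothetical 1024x768 RGB image into 512x512 crops.
# Only fully covered tiles are kept, so this example yields two crops from the top
# row while the remaining 256-pixel strip at the bottom is dropped:
#
#   big = Image.new("RGB", (1024, 768))
#   tiles = crop_image2images(big, crop_size=(512, 512))
#   assert all(t.size == (512, 512) for t in tiles)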
"""
将目录下的所有图片进行裁剪
@:param root_path: 图片的目录
@:param to: 原图片应该转移到哪里
@:param crop_size: 裁剪大小, tuple(crop_height, crop_width)
"""
def crop_images2small_images(root_path, to, crop_size):
image_paths = Path(root_path).glob(pattern="*")
for image_path in image_paths:
image = Image.open(image_path)
image_cropped_list = crop_image2images(image, crop_size)
for idx, image_cropped in enumerate(image_cropped_list):
image_cropped.save(
f"_{idx}".join(os.path.splitext(image_path))
)
image_cropped.close()
image.close()
if to is None:
Path(image_path).unlink(missing_ok=True)
else:
Path(image_path).rename(
str(
Path(to) / os.path.basename(image_path)
)
)
"""
判断是否能够多gpu分布式并行运算
"""
def distributed_enabled():
return torch.cuda.is_available() and torch.cuda.device_count() > 1 and torch.__version__ >= "0.4.0"
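
# Hedged sketch of how such a check is commonly used (not taken from this repo):
# wrap the model for multi-GPU training only when it is actually possible.
#
#   model = torch.nn.Linear(8, 2)                # stand-in model for illustration
#   if distributed_enabled():
#       model = torch.nn.DataParallel(model)     # or DistributedDataParallel after init_process_group
#   if torch.cuda.is_available():
#       model = model.cuda()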
if __name__ == "__main__":
# crop_images2small_images(
# root_path="dataset/train/images",
# to=None,
# crop_size=(512, 512)
# )
# augment_raw_images2(root_path="dataset/train/images")
crop_images2small_images(
root_path="dataset/test",
to=None,
crop_size=(512, 512)
)
# augment_raw_images2(root_path="dataset/val/images")
# resize_val_images(
# from_path="dataset/test",
# to_path=None,
# resized_shape=(1024, 1024)
# )
# convert_labelme_jsons2masks(
# classes=[
# "background",
# "leaf"
# ],
# root_path="dataset/train/images",
# mask_saved_path="dataset/train/labels"
# )