71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
![]() |
from torch.utils.data import Dataset, DataLoader
|
|||
|
import numpy as np
|
|||
|
from PIL import Image
|
|||
|
import os
|
|||
|
from torchvision import transforms
|
|||
|
from torch.utils.tensorboard import SummaryWriter
|
|||
|
from torchvision.utils import make_grid
|
|||
|
|
|||
|
# Shared TensorBoard writer used by the demo script below; it is constructed
# with "logs" as its log directory.
# NOTE(review): this is created at import time, so any process that imports
# this module (presumably including spawn-started DataLoader workers) would
# build its own writer — confirm whether it should live inside the entry point.
writer = SummaryWriter("logs")
|
|||
|
|
|||
|
class MyData(Dataset):
    """Dataset pairing each image file with a text label file.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``. Both directory listings are sorted; because an
    image and its label share the same file name, identical ordering keeps
    the i-th image aligned with the i-th label.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform):
        """Index the image and label directories.

        Args:
            root_dir: Directory that contains both sub-directories.
            image_dir: Name of the image sub-directory under ``root_dir``.
            label_dir: Name of the label sub-directory under ``root_dir``.
            transform: Callable applied to every opened image, or ``None``
                to return the image as opened.

        Raises:
            ValueError: If the two directories contain a different number
                of entries, since images and labels are paired by position.
        """
        super().__init__()
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # Labels and images share file names, so sorting both listings the
        # same way guarantees that index i refers to a matching pair.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform

        # Validate once, up front, with a real exception: an ``assert`` is
        # removed under ``python -O`` and would otherwise run on every len().
        if len(self.image_list) != len(self.label_list):
            raise ValueError(
                f"image/label count mismatch: {len(self.image_list)} entries in "
                f"{self.image_path!r} vs {len(self.label_list)} in "
                f"{self.label_path!r}"
            )

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            dict: ``{'img': <transformed image>, 'label': <str>}`` where the
            label is the first line of the label file with surrounding
            whitespace (including the trailing newline) removed.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])

        # Decode while the file is open so the handle can be released
        # deterministically instead of lingering until garbage collection.
        with Image.open(img_item_path) as opened:
            opened.load()
            img = opened if self.transform is None else self.transform(opened)

        with open(label_item_path, 'r') as f:
            label = f.readline().strip()

        return {'img': img, 'label': label}

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_list)
|
|||
|
|
|||
|
def main():
    """Build the ants + bees training set, log one sample and iterate it.

    Logs the image of sample 119 to TensorBoard under the tag ``'error'``,
    then walks the whole DataLoader printing each batch's type and image
    shape. The module-level ``writer`` is closed exactly once, even if
    loading fails part-way through.
    """
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    root_dir = "dataset/train"

    ants_dataset = MyData(root_dir, "ants_image", "ants_label", transform)
    bees_dataset = MyData(root_dir, "bees_image", "bees_label", transform)
    # Adding two Dataset objects concatenates them into a single dataset.
    train_dataset = ants_dataset + bees_dataset

    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)

    try:
        writer.add_image('error', train_dataset[119]['img'])

        for i, j in enumerate(dataloader):
            print(type(j))
            print(i, j['img'].shape)
    finally:
        # Single close point: flushes pending events whether or not the
        # loop above completed.
        writer.close()


if __name__ == '__main__':
    main()
|
|||
|
|
|||
|
|
|||
|
|