71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
from torch.utils.data import Dataset, DataLoader
|
||
import numpy as np
|
||
from PIL import Image
|
||
import os
|
||
from torchvision import transforms
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
from torchvision.utils import make_grid
|
||
|
||
# Module-level TensorBoard writer, constructed with "logs" as its log
# directory argument. NOTE: this runs at import time, not only when the file
# is executed as a script; the __main__ section below uses and closes it.
writer = SummaryWriter("logs")
|
||
|
||
class MyData(Dataset):
    """Dataset that pairs each image file with the label file at the same sorted position.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``. Both directory listings are sorted, and item
    ``idx`` combines the ``idx``-th image with the ``idx``-th label file.

    Args:
        root_dir: Directory that contains both ``image_dir`` and ``label_dir``.
        image_dir: Name of the sub-directory holding the image files.
        label_dir: Name of the sub-directory holding the label text files.
        transform: Optional callable applied to the opened PIL image. When
            ``None`` the PIL image is returned unchanged.

    Raises:
        ValueError: If the two directories contain a different number of
            entries, since images and labels could not be paired by position.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # Label and image files share the same names, so sorting both listings
        # the same way keeps every image aligned with its own label.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform

        # Validate up front with a real exception: an ``assert`` is stripped
        # under ``python -O`` and a mismatch would silently mis-pair samples.
        if len(self.image_list) != len(self.label_list):
            raise ValueError(
                f"image/label count mismatch: {len(self.image_list)} files in "
                f"{self.image_path!r} vs {len(self.label_list)} files in "
                f"{self.label_path!r}"
            )

    def __getitem__(self, idx):
        """Return ``{'img': image, 'label': label}`` for position ``idx``.

        ``image`` is the result of ``transform`` applied to the PIL image (or
        the PIL image itself when no transform was given); ``label`` is the
        first line of the label file without its trailing newline.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])

        # Image.open is lazy; decode the pixels while the file is still open
        # so the handle can be released deterministically by the ``with``.
        with Image.open(img_item_path) as img:
            img.load()

        label = self._read_label(label_item_path)

        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'label': label}

    @staticmethod
    def _read_label(label_item_path):
        """Return the first line of the label file without the line terminator."""
        with open(label_item_path, 'r') as f:
            # readline() keeps the "\n"; strip it so a label does not depend on
            # whether the file happens to end with a newline.
            return f.readline().rstrip('\n')

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_list)
|
||
|
||
if __name__ == '__main__':
    # Resize every image to 256x256 and convert it to a tensor before use.
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    root_dir = "dataset/train"

    image_ants = "ants_image"
    label_ants = "ants_label"
    ants_dataset = MyData(root_dir, image_ants, label_ants, transform)

    image_bees = "bees_image"
    label_bees = "bees_label"
    bees_dataset = MyData(root_dir, image_bees, label_bees, transform)

    # Adding two datasets yields one dataset: ants samples first, then bees.
    train_dataset = ants_dataset + bees_dataset
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)

    try:
        # Log one fixed sample under the 'error' tag (presumably a sample that
        # was being debugged). Guard the index so a smaller dataset does not
        # abort the script with an IndexError.
        preview_idx = 119
        if preview_idx < len(train_dataset):
            writer.add_image('error', train_dataset[preview_idx]['img'])

        # Each batch is the dict produced by MyData.__getitem__, collated by
        # the DataLoader; print its type and the image tensor's shape.
        for batch_idx, batch in enumerate(dataloader):
            print(type(batch))
            print(batch_idx, batch['img'].shape)
    finally:
        # Close the writer exactly once, even if loading or logging fails.
        writer.close()
|
||
|
||
|
||
|