Dataset
Dataset 类是 PyTorch(torch.utils.data)中常用的、用于读取数据集的类。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| import torch from torch.utils.data import Dataset from PIL import Image import os
class MyDataset(Dataset):
    """Image dataset backed by a single label folder.

    Every file found in ``<root_dir>/<label_dir>`` is treated as one sample,
    and the folder name ``label_dir`` is used as the label of each sample.
    """

    def __init__(self, root_dir, label_dir):
        """Record the folder location and list the files it contains.

        Args:
            root_dir: Directory that contains the label folders.
            label_dir: Name of the sub-folder holding the images; it is also
                returned as the label by ``__getitem__``.

        Raises:
            OSError: If ``<root_dir>/<label_dir>`` cannot be listed.
        """
        super().__init__()
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.img_dir = sorted(os.listdir(self.path))

    def __getitem__(self, idx) -> "tuple[Image.Image, str]":
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is converted to RGB mode; the label is ``label_dir``.
        """
        img_name = self.img_dir[idx]
        img_item_pth = os.path.join(self.path, img_name)
        # Image.open is lazy and keeps the file open. convert() loads the
        # pixel data and returns a new image, so the handle can be released
        # as soon as the conversion is done.
        with Image.open(img_item_pth) as raw:
            img = raw.convert('RGB')
        label = self.label_dir
        return img, label

    def __len__(self) -> int:
        """Return the number of entries listed in the label folder."""
        return len(self.img_dir)
if __name__ == '__main__':
    # Location of the training split and the two class folders inside it.
    root_dir = 'Data/hymenoptera_data/train'
    ants_dir, bees_dir = 'ants', 'bees'

    # Build one dataset per class folder, ants first and then bees.
    ants, bees = (
        MyDataset(root_dir, class_dir) for class_dir in (ants_dir, bees_dir)
    )
|