pytorch导入自定义数据集
最近刚学图神经网络,数据集导入折腾了很久,终于开窍了一点。 目前常用的数据导入方法主要有两种:
(1)torchvision自带的导入方式: 这种导入方式使用了torchvision自带的库,打开函数进去看它的说明是这样的: 直接翻译过来意思就是图片要放在相应类别的文件夹下,文件夹名字就是图片所属的类别。
导入代码如下:
from torchvision import datasets transform可自行定义 train_transforms = transforms.Compose( [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)), transforms.RandomRotation(degrees=15), transforms.RandomHorizontalFlip(), transforms.CenterCrop(size=224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) train_dataset=datasets.ImageFolder(train_dir,transform=train_transforms)
2.自定义数据导入方式 现实使用过程中经常会遇到图片跟标签是分开放置的情况,如下面两张图所示,图片和label分别放置的,那么torchvision自带的库就不能用了,需要自定义数据读取方式。 首先用os库遍历文件,提取图片的名字和对应的label,保存在CSV文件中(当然完整的程序不保存也可以,这里是为了方便后面用),
开始自定义导入数据的类,这部分的格式都是统一的,最开始先写上这几个必须的函数,再往里面填东西:
from torch.utils.data import Dataset#Dataset是必须要继承的 class LoadData(Dataset): def __init__(self,image_path,transform=None): #初始化,读取数据集 def __getitem__(self,index): #对于指定id,获取该数据并返回 def __len__(self): #获取数据及总大小
确定模板以后直接往里面填东西就可以了:
from torch.utils.data import Dataset#Dataset是必须要继承的 import pandas as pd from PIL import Image class LoadData(Dataset): def __init__(self,image_path,transform=None): self.imgs_info=pd.read_csv(image_path) def __getitem__(self,index): img_path,label=self.imgs_info[img_path],self.imgs_info[weather] img=Image.open(img_path)#得到路径需要打开图片 img=img.convert(RGB)#将图片转为张量 if transform is not None: img=transform(img)#图像变换 returnimg,label def __len__(self): return len(self.imgs_info)
主函数中调用:
from torchvision import transforms train_csv_path=r./dataset/train.csv train_transforms=transforms.Compose( [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)), transforms.RandomRotation(degrees=15), transforms.RandomHorizontalFlip(), transforms.CenterCrop(size=224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) train_dataset=LoadData(train_csv_path,transform=train_transforms) train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)
上一篇:
JS实现多线程数据分片下载
下一篇:
深度学习准确率提升之天花板分析