PyTorch: converting images and str labels to tensors
Folder layout of the dataset (dataset path: file_path = r".\Dataset"):
```
Dataset
|------train
|      |------000000
|      |------000001
|      |      ...
|      |------000305
|------test
|      |------000000
|      |------000001
|      |      ...
|      |------000305
|------valid
       |------000000
       |------000001
       |      ...
       |------000305
```
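```python
import os
from PIL import Image

# ClassToList, MyDataSet and img_transforms are defined further down in this post;
# train_transforms / test_transforms are not shown in the original
path = ["test", "train", "valid"]
file_path = r".\Dataset"
# class306 records the class label (str) for each index
class306 = ClassToList(os.path.join(file_path, "test"))
train_set = []
test_set = []
valid_set = []

for i in range(len(path)):
    the_path = os.path.join(file_path, path[i])
    x, y = [], []  # image tensors and label indices of this split
    j = 0          # label index of the current class folder
    # sorted() gives every split the same folder -> index mapping
    for class_path in sorted(os.listdir(the_path)):
        class_dir = os.path.join(the_path, class_path)
        print("reading " + class_dir)
        for img_path in os.listdir(class_dir):
            img = Image.open(os.path.join(class_dir, img_path)).convert("RGB")
            # img_transforms converts the PIL.Image into a tensor of the required format
            x.append(img_transforms(img))
            y.append(j)
        j += 1
        print("Have read " + class_dir)
    print("size: ", len(x))
    if i == 1:
        train_set = MyDataSet(x, y, train_transforms)
    elif i == 0:
        test_set = MyDataSet(x, y, test_transforms)
    elif i == 2:
        valid_set = MyDataSet(x, y, test_transforms)
```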
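The loop above decodes and transforms every image up front, so the whole dataset sits in memory. An alternative sketch (the class name LazyImageDataset is hypothetical, not part of the original code) stores only file paths and label indices and opens each image inside __getitem__. torchvision.datasets.ImageFolder follows the same idea and likewise assigns indices from the sorted folder names.

```python
import os
from PIL import Image
from torch.utils.data import Dataset


class LazyImageDataset(Dataset):
    """Hypothetical variant: keeps (path, label) pairs and loads images on demand."""

    def __init__(self, split_dir, transform=None):
        # Sorted class folder names -> label indices 0, 1, 2, ...
        self.classes = sorted(
            d for d in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, d))
        )
        self.samples = []
        for label, cls in enumerate(self.classes):
            class_dir = os.path.join(split_dir, cls)
            for name in sorted(os.listdir(class_dir)):
                self.samples.append((os.path.join(class_dir, name), label))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_path, label = self.samples[index]
        img = Image.open(img_path).convert("RGB")  # decoded only when requested
        if self.transform is not None:
            img = self.transform(img)  # e.g. img_transforms from this post
        return img, label
```

Passing img_transforms as transform makes __getitem__ return tensors, at the cost of decoding each image every time it is accessed.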
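- transforms.Resize() can resize images opened with PIL.Image.open(), but not the NumPy arrays returned by io.imread() (img = io.imread(os.path.join(the_path, class_path, img_path))); passing such an array to Resize raises an error. Because all image tensors must end up with the same size (so that they can be stacked into batches), the images are read with PIL.Image here.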
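- The image labels are strings and have to be converted to torch.Tensor. However, a list or tuple whose elements are strings cannot be converted to a torch.Tensor directly. Two ways to solve this:
(1) Use LabelEncoder from sklearn's preprocessing module.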
```python
import os
import torch
from sklearn import preprocessing

# ClassToList reads all class names from the folder:
# ["000000", "000001", "000002", ..., "000305"]
labels = ClassToList(os.path.join(file_path, "test"))
le = preprocessing.LabelEncoder()
targets = le.fit_transform(labels)  # targets: array([0, 1, 2, ..., 305])
targets = torch.as_tensor(targets)  # targets: tensor([0, 1, 2, ..., 305])
```
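LabelEncoder stores the fitted classes, sorted, in classes_, so predicted indices can be mapped back to the original string labels with inverse_transform. A small sketch with made-up per-image labels:

```python
from sklearn import preprocessing

labels = ["000002", "000000", "000001", "000000"]  # toy per-image string labels
le = preprocessing.LabelEncoder()
targets = le.fit_transform(labels)  # classes_ become ["000000", "000001", "000002"]
print(targets)                      # [2 0 1 0]

# Map indices (e.g. model predictions) back to the string labels
decoded = le.inverse_transform([2, 0]).tolist()
print(decoded)                      # ['000002', '000000']
```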
(2) Store the string labels in a list and replace each string label with its index in that list.
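```python
import os
import torch

# ClassToList reads all class names from the folder: ["000000", ..., "000305"]
test_dir = os.path.join(file_path, "test")
class306 = ClassToList(test_dir)
targets = []
for class_path in os.listdir(test_dir):
    label = class306.index(class_path)  # index of this class name in class306
    for each_img in os.listdir(os.path.join(test_dir, class_path)):
        targets.append(label)  # store the index of the image's label
targets = torch.tensor(targets)  # convert to torch.Tensor
```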
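When the string labels come as one label per image (rather than one folder per class), the same index replacement can be done with a dict built once from the class list, which avoids scanning the list for every image. A minimal sketch; encode_labels is a hypothetical helper and the labels are toy values:

```python
import torch


def encode_labels(str_labels, classes):
    """Hypothetical helper: map each string label to its index in `classes`."""
    class_to_idx = {name: idx for idx, name in enumerate(classes)}  # built once
    return torch.tensor([class_to_idx[s] for s in str_labels], dtype=torch.long)


classes = ["000000", "000001", "000002"]  # e.g. the list returned by ClassToList
targets = encode_labels(["000001", "000001", "000002", "000000"], classes)
print(targets)  # tensor([1, 1, 2, 0])
```

The same classes list must be used for train, test and valid, otherwise one index would mean different classes in different splits.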
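- transforms.Resize() must be placed before transforms.ToTensor(): the image is resized while it is still a PIL.Image and only then converted to a tensor (torchvision versions before 0.8 do not accept tensors in Resize at all). transforms.Normalize(), by contrast, operates on tensors, so it comes after ToTensor().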
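```python
from torchvision import transforms

# Converts each image read with PIL.Image into the required tensor format
# (size, mean and stdv are not defined in this post)
img_transforms = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=stdv),
])
```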
ClassToList():
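```python
import os

def ClassToList(file_path):
    # Returns the class folder names (str); sorted so that the
    # index -> class mapping is the same on every run and for every split
    class306 = []
    for class_path in sorted(os.listdir(file_path)):
        class306.append(class_path)
    return class306
```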
MyDataSet():
```python
import torch
from torch.utils.data import Dataset

# Inherits from torch.utils.data.Dataset
class MyDataSet(Dataset):
    def __init__(self, x, y, transform):
        # Convert the label list to a tensor if it is not one already
        if not isinstance(y, torch.Tensor):
            print("Converting y from", type(y), end=" ")
            self.y = torch.tensor(y)
            print("to", type(self.y))
        else:
            self.y = y
        self.idx = list(x)  # the images (already tensors)
        # Stored but not applied below: images were transformed while loading
        self.transform = transform

    def __getitem__(self, index):
        input_data = self.idx[index]
        target = self.y[index]
        return input_data, target

    def __len__(self):
        return len(self.idx)
```
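MyDataSet only indexes data that is already in memory. PyTorch ships torch.utils.data.TensorDataset for exactly that case; it indexes tensors along their first dimension, so the list of image tensors is stacked first. A sketch with random stand-in data (not the real dataset), wrapped in a DataLoader to produce batches:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the x (image tensors) and y (label indices) lists built while loading
x = [torch.rand(3, 32, 32) for _ in range(10)]
y = [i % 3 for i in range(10)]

# Stack the images into one (N, 3, H, W) tensor; labels become a LongTensor
dataset = TensorDataset(torch.stack(x), torch.tensor(y))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])
```

torch.stack only works because every image tensor has the same shape, which is why the images are resized before ToTensor.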