PyTorch Tips: TensorDataset
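TensorDataset bundles several tensors together, much like Python's built-in zip. The class indexes every tensor along its first dimension, so all tensors passed to it must have the same size in that dimension.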
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)

# Slice the dataset
print(train_ids[0:2])
print("=" * 80)

# Iterate over the samples one by one
for x_train, y_label in train_ids:
    print(x_train, y_label)

# Wrap the dataset in a DataLoader
print("=" * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):
    # Note: enumerate yields two values, the index and the data
    # (the data holds both the training features and the labels)
    x_data, label = data
    print("batch:{0} x_data:{1} label: {2}".format(i, x_data, label))
Output (the batch contents vary from run to run because shuffle=True):
(tensor([[1, 2, 3],
        [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
batch:1 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6]]) label: tensor([44, 44, 55, 55])
batch:2 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [7, 8, 9],
        [7, 8, 9]]) label: tensor([55, 66, 66, 66])
batch:3 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [7, 8, 9],
        [4, 5, 6]]) label: tensor([44, 44, 66, 55])
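To make the zip-like behaviour more concrete, here is a minimal pure-Python sketch of the idea: apply the same index to every input and return the results as a tuple. The class name TinyTensorDataset is hypothetical and this is a simplified stand-in that works on plain lists, not PyTorch's actual implementation.

```python
class TinyTensorDataset:
    """Hypothetical, simplified stand-in for TensorDataset (illustration only)."""

    def __init__(self, *tensors):
        # Every input must have the same size in its first dimension.
        first_len = len(tensors[0])
        if any(len(t) != first_len for t in tensors):
            raise ValueError("size mismatch in the first dimension")
        self.tensors = tensors

    def __getitem__(self, index):
        # Apply the same index (an int or a slice) to every input.
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        return len(self.tensors[0])


features = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
labels = [44, 55, 66]
ds = TinyTensorDataset(features, labels)

print(len(ds))   # 3
print(ds[0])     # ([1, 2, 3], 44)
print(ds[0:2])   # ([[1, 2, 3], [4, 5, 6]], [44, 55])
```

Indexing with an integer returns one (sample, label) pair, while a slice returns a tuple of sliced inputs, which matches the shape of the slice output shown above.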
Note: every argument passed to TensorDataset must be a tensor.
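If the data starts out as plain Python lists, one way to satisfy this requirement is to convert it first. The sketch below uses torch.as_tensor for that; the variable names and sample values are made up for illustration.

```python
import torch
from torch.utils.data import TensorDataset

# Illustrative raw data held in plain Python lists (not tensors yet).
raw_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
raw_labels = [0, 1, 0]

# Convert to tensors before handing the data to TensorDataset.
features = torch.as_tensor(raw_features, dtype=torch.float32)
labels = torch.as_tensor(raw_labels, dtype=torch.long)

dataset = TensorDataset(features, labels)
print(len(dataset))  # 3

x, y = dataset[1]
print(x, y)  # tensor([3., 4.]) tensor(1)
```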