pytorch深度学习实践-加载数据集0106
B站 刘二大人:
1、代码实现Mini-Batch进行训练
import torch
import numpy as np
from torch.utils.data import Dataset    # abstract base class for datasets
from torch.utils.data import DataLoader  # batches/shuffles/parallel-loads a Dataset


# prepare dataset
class DiabetesDataset(Dataset):
    """Dataset reading the diabetes CSV: all columns but the last are
    features, the last column is the binary label."""

    def __init__(self, filepath):
        # filepath: path of the CSV file to load
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # shape is (rows, cols); rows = number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # every column except the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D: (N, 1)

    def __getitem__(self, index):
        # magic method: enables dataset[index]
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # magic method: enables len(dataset)
        return self.len


# design model using class
class Model(torch.nn.Module):
    """Three-layer fully-connected classifier 8 -> 6 -> 4 -> 1 with a
    sigmoid after every layer; final output is a probability in (0, 1)."""

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


# training cycle: forward, backward, update.
# Everything that touches the CSV or spawns DataLoader workers lives under
# the __main__ guard: num_workers=2 re-imports this module in each worker
# process (mandatory on Windows spawn), and importing the module should not
# trigger file I/O or training as a side effect.
if __name__ == '__main__':
    dataset = DiabetesDataset('diabetes.csv')
    # shuffle: randomize sample order each epoch; num_workers: parallel
    # loader worker processes
    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=2)

    model = Model()

    # construct loss and optimizer
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(100):
        # train_loader yields shuffled mini-batches; enumerate from 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data  # inputs = x batch, labels = y batch
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
2、对 for i, data in enumerate(train_loader, 0): 的解释
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
enumerate(sequence, [start=0])
-
sequence -- 一个序列、迭代器或其他支持迭代对象。 start -- 下标起始位置。
>>> seq = ['one', 'two', 'three']
>>> for i, element in enumerate(seq):
...     print(i, element)
...
0 one
1 two
2 three
上一篇:
通过多线程提高代码的执行效率例子
下一篇:
单点登录的三种实现方式