PyTorch Deep Learning Practice: Loading Datasets (0106)
From the Bilibili lectures by 刘二大人:
1. Implementing Mini-Batch training in code
import torch
import numpy as np
from torch.utils.data import Dataset     # abstract base class for datasets
from torch.utils.data import DataLoader  # helps load data in mini-batches

# prepare dataset
class DiabetesDataset(Dataset):  # inherits from Dataset
    def __init__(self, filepath):  # filepath: where the data file comes from
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # shape is (rows, columns); rows gives N, the sample count
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):  # magic method: enables indexing with dataset[index]
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
# shuffle: randomize the sample order; num_workers: number of parallel worker processes

# design model using class
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# training cycle: forward, backward, update
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
            # each step, train_loader yields one mini-batch (x, y) into data,
            # enumerated from 0; the loader shuffles first, then forms mini-batches
            inputs, labels = data  # inputs = x, labels = y
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
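
To see concretely what train_loader hands the loop each step, here is a minimal, self-contained sketch; the random tensors standing in for diabetes.csv (100 samples, 8 features, binary labels) are an assumption for illustration only:

import torch
from torch.utils.data import TensorDataset, DataLoader

# hypothetical stand-in for diabetes.csv: 100 samples, 8 features, binary labels
x = torch.randn(100, 8)
y = torch.randint(0, 2, (100, 1)).float()

loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
for i, (inputs, labels) in enumerate(loader):
    print(i, inputs.shape, labels.shape)
# 0 torch.Size([32, 8]) torch.Size([32, 1])
# 1 torch.Size([32, 8]) torch.Size([32, 1])
# 2 torch.Size([32, 8]) torch.Size([32, 1])
# 3 torch.Size([4, 8]) torch.Size([4, 1])   <- the 100 % 32 = 4 leftover samples

One design note: with num_workers > 0 on Windows, the worker processes re-import the main script, which is why the training loop in the listing above sits under the if __name__ == '__main__': guard.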
2. Explaining for i, data in enumerate(train_loader, 0):
enumerate() wraps an iterable (such as a list, tuple, or string) into an indexed sequence, yielding each item together with its index; it is typically used in a for loop.
enumerate(sequence, [start=0])
sequence -- a sequence, iterator, or any other object that supports iteration.
start -- the starting value of the index.
>>> seq = ['one', 'two', 'three']
>>> for i, element in enumerate(seq):
...     print(i, element)
...
0 one
1 two
2 three
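The optional start argument shifts where the index begins, which is handy when you want 1-based step numbers:

>>> for i, element in enumerate(seq, start=1):
...     print(i, element)
...
1 one
2 two
3 three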
