pytorch深度学习实践-加载数据集0106

B站 刘二大人:

1、代码实现Mini-Batch进行训练

import torch
import numpy as np
from torch.utils.data import Dataset  # 抽象类
from torch.utils.data import DataLoader  # 帮助加载数据


# prepare dataset


class DiabetesDataset(Dataset):  # subclass of torch.utils.data.Dataset
    """Map-style dataset backed by a numeric CSV file.

    Every column except the last is a feature; the last column is the label.
    The whole file is loaded into memory once, in ``__init__``.

    Args:
        filepath: Path to a comma-separated file containing only numeric
            rows (no header row is skipped).
    """

    def __init__(self, filepath):
        # ndmin=2 keeps a single-row file 2-D; without it np.loadtxt returns a
        # 1-D array and the column slicing below raises IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]  # number of samples (rows)
        # Features: every column but the last -> shape (N, num_features).
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Label: last column, indexed with a list so it stays 2-D, shape (N, 1).
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples (rows read from the CSV)."""
        return self.len


dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
# shuffle=True: reshuffle the sample order at the start of every epoch.
# batch_size=32: the last batch may be smaller (drop_last defaults to False).
# num_workers=2: load batches in 2 worker *processes* (not threads).
# NOTE: under the "spawn" start method (default on Windows/macOS) each worker
# re-imports this module, so this module-level code runs again in every
# worker; the training loop must stay under the `__main__` guard.


# design model using class

class Model(torch.nn.Module):
    """Fully connected binary classifier: 8 -> 6 -> 4 -> 1 features.

    A sigmoid follows every linear layer, including the last, so the output
    lies in (0, 1) and can be fed directly to ``BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order so parameter initialization
        # (and the state_dict keys) stay stable.
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Apply each linear layer followed by the shared sigmoid activation."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.sigmoid(layer(hidden))
        return hidden


model = Model()

# construct loss and optimizer
# Binary cross-entropy on the sigmoid outputs, averaged over the mini-batch.
criterion = torch.nn.BCELoss(reduction='mean')
# Plain SGD over all model parameters, learning rate 0.01.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# training cycle forward, backward, update
if __name__ == '__main__':
    for epoch in range(100):
        # Each pass over train_loader yields shuffled (inputs, labels)
        # mini-batches; enumerate(..., 0) numbers them starting from 0.
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data  # inputs: x, labels: y
            y_pred = model(inputs)  # forward pass
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())

            optimizer.zero_grad()  # clear gradients left from the previous step
            loss.backward()  # backward pass: compute new gradients

            optimizer.step()  # update the parameters

2、对 for i, data in enumerate(train_loader, 0): 的解释

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。在上面的训练循环中,train_loader 每次迭代产出一个 mini-batch,即 (inputs, labels) 二元组;enumerate 为它附加从 0 开始的批次下标 i。

enumerate(sequence, start=0)
    sequence -- 一个序列、迭代器或其他支持迭代的对象。
    start -- 下标起始位置,默认为 0。
>>> seq = ['one', 'two', 'three']
>>> for i, element in enumerate(seq):
...     print(i, element)
...
0 one
1 two
2 three
经验分享 程序员 微信小程序 职场和发展