深度学习09图片数据集
#图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# Render matplotlib figures as SVG (d2l helper).
d2l.use_svg_display()
# transforms.ToTensor converts each PIL image into a 32-bit float tensor and
# divides by 255, so every pixel value lies between 0 and 1.
trans = transforms.ToTensor()
# Training and test splits of Fashion-MNIST, downloaded into ./data if absent.
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True,
                                                transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False,
                                               transform=trans, download=True)
# Number of samples in each split, then the shape of the first image tensor.
print(len(mnist_train), len(mnist_test))
print(mnist_train[0][0].shape)
def get_fashion_mnist_labels(labels):
    """Return the Fashion-MNIST text labels for numeric class indices.

    Args:
        labels: Iterable of class indices (ints, or anything ``int()``
            accepts, such as 0-d tensors yielded by iterating a label tensor).

    Returns:
        list[str]: The text label for each index, in input order.
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, title=None, scale=1.5):
    """Plot a grid of images.

    Args:
        imgs: Sequence of images; each item is either a torch tensor or
            something ``Axes.imshow`` accepts directly (e.g. a PIL image).
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        title: Optional sequence of per-image titles, indexed in parallel
            with ``imgs``. When ``None``, no titles are drawn.
        scale: Width/height of each grid cell in inches.

    Returns:
        The flattened array of matplotlib Axes, so callers can adjust the plot.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D array of Axes, so .flatten() also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    # zip stops at the shorter sequence: surplus images or cells are ignored.
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image: detach and move to CPU so .numpy() always works.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # PIL image (or other array-like)
            ax.imshow(img)
        # Tick labels carry no information for an image grid.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if title is not None:
            ax.set_title(title[i])
    return axes
# Preview the first 18 training images in a 2 x 9 grid, titled by label.
preview_loader = data.DataLoader(mnist_train, batch_size=18)
x, y = next(iter(preview_loader))
# Reshape the batch to (18, 28, 28) so imshow gets one 2-D array per image
# (presumably this drops a singleton channel dimension — confirm).
show_images(x.reshape(18, 28, 28), 2, 9, title=get_fashion_mnist_labels(y))
# When running as a plain script rather than in Jupyter, call
# d2l.plt.show() here to open the figure window.

# Mini-batch size for the training DataLoader below.
batch_size = 256
def get_dataloader_workers():
    """Return how many worker processes a DataLoader should use.

    The count is fixed at 4; callers pass it as ``num_workers``.
    """
    worker_count = 4
    return worker_count
if __name__ == "__main__":
    # Multi-worker DataLoaders re-import the main module in each worker on
    # platforms using the "spawn" start method (Windows, macOS). Without this
    # guard the loop below would be re-run while a worker is bootstrapping,
    # which typically raises a RuntimeError. (In Jupyter, __name__ is
    # "__main__", so this still runs there.)
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    # Time one full pass over the training set to gauge data-loading speed.
    timer = d2l.Timer()
    for x, y in train_iter:
        continue
    print(f'{timer.stop():.2f} sec')
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and return DataLoaders for its two splits.

    Args:
        batch_size: Mini-batch size for both loaders.
        resize: Optional size passed to ``transforms.Resize``. It is applied
            before ``ToTensor``. Falsy values (``None``, 0) mean no resizing.

    Returns:
        tuple: ``(train_iter, test_iter)``. The training loader reshuffles
        every epoch. The test loader keeps the dataset order, so evaluation
        passes are deterministic.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Insert at the front so resizing happens on the raw image,
        # before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            # Evaluation data does not need shuffling.
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
上一篇:
JS实现多线程数据分片下载
