No such operator image::read_file问题解决
在学习动手学深度学习这门课的时候,在13.6 节⽬标检测数据集这一章遇到了问题,读取数据的时候报错:No such operator image::read_file
网上有人说问题在于pytorch版本和torchvision版本不对应,可以通过重新安装,也可以通过找到pytorch和torchvision相对应的版本。我都试过了,没有用(至少对于我是这样的)。
再仔细观察,发现是torchvision.io.read_image()这个函数报错,这个函数调用了torch.ops.image.read_file(),正是后者出错。搞不懂这个函数是干什么的
不是办法的办法是,放弃使用torchvision.io.read_image(),使用Image.open()函数,原来的代码是
#@save def read_data_bananas(is_train=True): """读取香蕉检测数据集中的图像和标签""" data_dir = d2l.download_extract(banana-detection) csv_fname = os.path.join(data_dir, bananas_train if is_train else bananas_val, label.csv) csv_data = pd.read_csv(csv_fname) csv_data = csv_data.set_index(img_name) images, targets = [], [] for img_name, target in csv_data.iterrows(): images.append(torchvision.io.read_image( os.path.join(data_dir, bananas_train if is_train else bananas_val, images, f{img_name}))) # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y), # 其中所有图像都具有相同的香蕉类(索引为0) targets.append(list(target)) return images, torch.tensor(targets).unsqueeze(1) / 256
修改后的代码如下:
#@save def read_data_bananas(is_train=True): """读取香蕉检测数据集中的图像和标签""" # Image.open()读出来的图片是PIL格式,要转换为tensor格式 totensor = transforms.ToTensor() data_dir = d2l.download_extract(banana-detection) csv_fname = os.path.join(data_dir, bananas_train if is_train else bananas_val, label.csv) csv_data = pd.read_csv(csv_fname) csv_data = csv_data.set_index(img_name) images, targets = [], [] for img_name, target in csv_data.iterrows(): images.append(totensor(Image.open( os.path.join(data_dir, bananas_train if is_train else bananas_val, images, f{img_name})))) # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y), # 其中所有图像都具有相同的香蕉类(索引为0) targets.append(list(target)) return images, torch.tensor(targets).unsqueeze(1) / 256
具体来说,假设图片路径是path,将torchvision.io.read_image(path)替换为以下代码:
from PIL import Image import torchvision.transforms as transforms totensor = transforms.ToTensor() totensor(Image.open(path))
需要注意的是,transforms.ToTensor()处理后的图片像素位于0-1之间,不用再除以255,否则图片变成全黑色的了,并且画图的时候要改变通道位置(变成numpy格式)。下面画出数据集的第一张图片
import matplotlib.pyplot as plt img = batch[0][0].permute(1, 2, 0) plt.imshow(img)
上一篇:
JS实现多线程数据分片下载
下一篇:
申请专利需要具备什么条件