理解transforms.ToTensor()函数
先说结论:
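transforms.ToTensor() 的输入只能是两种类型:PIL 图像,或 numpy.ndarray(cv2 读取的图像本身就是 numpy.ndarray,因此可以直接传入)。输入不能已经是 tensor,因为该函数的作用就是把图像转换为 tensor。转换时会把 H x W x C 的布局变为 C x H x W;只有 uint8 类型的数据会被除以 255 缩放到 [0.0, 1.0],其他类型的数据不做缩放、原样返回。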
附上源码:
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0] if the PIL Image belongs to one of the modes
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray
    has dtype = np.uint8.

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """Convert ``pic`` by delegating to ``F.to_tensor``.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        # The transform takes no parameters, so the repr is just "ToTensor()".
        return self.__class__.__name__ + '()'


# ---------------------------------------------------------------------------
# The F.to_tensor(pic) function
# ---------------------------------------------------------------------------
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a numpy.ndarray.
        ValueError: If ``pic`` is a numpy.ndarray that is not 2- or
            3-dimensional.
    """
    if not (_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Give a 2-D array a trailing channel axis so the transpose
            # below always sees three dimensions.
            pic = pic[:, :, None]

        # HWC -> CHW
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: only byte (uint8) data is scaled to [0, 1]
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    # NOTE(review): accimage looks like an optional image backend that is
    # None when unavailable -- confirm against the module's imports.
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        # Multiplied by 255 here; presumably so that the div(255) at the end
        # maps the bilevel values back onto 0.0 / 1.0 -- TODO confirm.
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # pic.size is (width, height), so this view is laid out as H x W x C.
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
下面是一些操作的示例:
import cv2
import numpy as np
import torchvision.transforms as trans
from PIL import Image

trans1 = trans.ToTensor()

# 1) PIL image: ``size`` is reported as (w, h); the tensor is (c, h, w).
img = Image.open('xxx.jpg')
print(img.size)  # w, h
img_PIL_tensor = trans1(img)
print(img_PIL_tensor.size())  # c, h, w

# 2) OpenCV image: cv2.imread returns a numpy array shaped (h, w, c).
img = cv2.imread('xxx.jpg')
print(img.shape)  # h, w, c
img_cv2_tensor = trans1(img)
print(img_cv2_tensor.size())  # c, h, w

# 3) Plain numpy array. np.zeros defaults to float64, so the result is not a
#    ByteTensor and is returned as-is, without being divided by 255.
img = np.zeros([100, 200, 3])  # h, w, c
img_np_tensor = trans1(img)
print(img_np_tensor.size())  # c, h, w -> torch.Size([3, 100, 200])
上一篇:
JS实现多线程数据分片下载