理解transforms.ToTensor()函数
先说结论:
transforms.ToTensor()的操作对象有两种:PIL格式的图像,以及numpy.ndarray(cv2读取的图像就是ndarray,因此也可以)。输入不能是tensor格式,因为该函数的作用正是把图像转换为tensor;传入tensor会在类型检查处抛出TypeError。
附上源码:
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0] if the PIL Image belongs to one of the modes
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray
    has dtype = np.uint8.

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """Convert ``pic`` by delegating to ``F.to_tensor``.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # NOTE(review): ``F`` is not defined in this snippet; presumably it is
        # ``torchvision.transforms.functional`` -- confirm against the imports.
        return F.to_tensor(pic)

    def __repr__(self):
        """Return the class name followed by ``()``, e.g. ``ToTensor()``."""
        # The suffix must be the string '()'. Concatenating the class name
        # with an empty tuple raises TypeError.
        return self.__class__.__name__ + '()'
    
###################### F.to_tensor(pic)函数 
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image. For ndarray and PIL inputs the result is in
        (C x H x W) layout; byte tensors are converted to float and divided
        by 255, every other dtype is returned without scaling.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a numpy.ndarray.
        ValueError: If ``pic`` is a numpy.ndarray that fails the
            ``_is_numpy_image`` check (reported as not 2/3 dimensional).
    """
    if not (_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Add a trailing channel axis so the transpose below applies.
            pic = pic[:, :, None]

        # HWC -> CHW
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: only byte data is rescaled to [0.0, 1.0]
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    # ``accimage`` is an optional backend; it is None when unavailable.
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image: the mode string selects the numpy dtype to read with
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        # Multiply by 255 so the byte-tensor rescale below yields 0.0 / 1.0.
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # PIL reports size as (width, height); reshape to H x W x C.
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
下面是一些操作的示例:
# Usage example: ToTensor applied to a PIL image, a cv2 image and a raw ndarray.
import cv2
import numpy as np
import torchvision.transforms as trans
from PIL import Image

trans1 = trans.ToTensor()

# PIL image as input
img = Image.open('xxx.jpg')
print(img.size)  # w, h
img_PIL_tensor = trans1(img)
print(img_PIL_tensor.size())  # c, h, w

# cv2 image (a numpy.ndarray) as input
img = cv2.imread('xxx.jpg')
print(img.shape)  # h, w, c
img_cv2_tensor = trans1(img)
print(img_cv2_tensor.size())  # c, h, w

# Plain numpy.ndarray as input. np.zeros defaults to a float dtype, so this
# array is not byte data and is returned without the /255 scaling.
img = np.zeros([100, 200, 3])  # h, w, c
img_np_tensor = trans1(img)
print(img_np_tensor.size())  # c, h, w -> torch.Size([3, 100, 200])
上一篇:
			            JS实现多线程数据分片下载 
			          
			          
			        