理解transforms.ToTensor()函数
先说结论:
transforms.ToTensor()的操作对象有两种:PIL格式的图像,以及numpy.ndarray(cv2读取的图像本身就是ndarray,所以也可以)。输入不能已经是tensor格式,因为该函数的作用正是把图像转换为tensor。
附上源码:
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # All conversion logic is delegated to the functional helper.
        return F.to_tensor(pic)

    def __repr__(self):
        # Use the runtime class name so subclasses print their own name.
        return self.__class__.__name__ + '()'
###################### F.to_tensor(pic)函数
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a numpy.ndarray.
        ValueError: If ``pic`` is a numpy.ndarray that fails the
            ``_is_numpy_image`` check.
    """
    if not(_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Append a trailing axis so the transpose below always sees 3 axes.
            pic = pic[:, :, None]

        # Move the last axis to the front (H x W x C -> C x H x W).
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        # Only byte tensors are rescaled by 1/255; other dtypes pass through.
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage copies its pixels into a preallocated float32 C x H x W array.
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        # Multiplied by 255 here; the byte-tensor branch at the end divides
        # by 255 again.
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # pic.size is (width, height), so index 1 is the height.
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
下面是一些操作的示例:
# Example: ToTensor applied to a PIL image, a cv2 image and a plain ndarray.
import cv2
import numpy as np
import torchvision.transforms as trans
from PIL import Image

trans1 = trans.ToTensor()

# PIL input
img = Image.open('xxx.jpg')
print(img.size)  # w, h
img_PIL_tensor = trans1(img)
print(img_PIL_tensor.size())  # c, h, w

# cv2 input (cv2.imread returns a numpy.ndarray)
img = cv2.imread('xxx.jpg')
print(img.shape)  # h, w, c
img_cv2_tensor = trans1(img)
print(img_cv2_tensor.size())  # c, h, w

# plain numpy input
img = np.zeros([100, 200, 3])  # h, w, c
img_np_tensor = trans1(img)
print(img_np_tensor.size())  # c, h, w
# Output of the last print: torch.Size([3, 100, 200])
上一篇:
JS实现多线程数据分片下载
