快捷搜索: 王者荣耀 脱发

Pytorch中的 torch.as_tensor() 和 torch.from_numpy() 的区别

之前我写过,比较了 torch.Tensor() 和 torch.tensor() 的区别,而这两者都是深拷贝的方法,返回张量的同时,会在内存中创建一个额外的数据副本,与原数据不共享内存,所以不受原数据改变的影响。

这里,介绍两种浅拷贝方法,torch.as_tensor() 和 torch.from_numpy(),同样还是从官方文档出发:

直接对比着来看,最明显的,torch.as_tensor() 接收三个参数(data, dtype=None, device=None),而 torch.from_numpy() 只接收一个参数,即 ndarray,正如前者文档中提到的:

如果 data 是具有相同 dtype 和 device 的 NumPy 数组(一个 ndarray),则使用 torch.from_numpy() 构造一个张量。

torch.from_numpy() 会根据输入的 ndarray 构造一个具有相同 dtype 与 device 的张量,这个行为跟 torch.as_tensor() 只输入ndarray,而 dtype 与 device 保持默认是一样的。从这个角度看,可以说 torch.from_numpy() 是 torch.as_tensor() 的一个特例。它们都是浅拷贝,这是因为张量(tensor)与数组(ndarray)存储共享相同的底层缓冲区,改变其中一个的值都会是另一个的值也被改变。

但是,torch.as_tensor() 显然适用性更广,它既可以接收非 ndarray 的数据,还能改变数据的 dtype 或 device,但这两个东西一旦被改变了,就会生成一个新的数据副本,此时 torch.as_tensor() 的行为就变成深拷贝了。

需要注意的是,Numpy 在 64 位机子上浮点数默认的数据类型是 float64,而 Pytorch 默认的是 float32。所以为了确保转换后的数据类型是 float32,以及兼顾适用性,使用 torch.as_tensor() 都是更好的选择。

经验分享 程序员 微信小程序 职场和发展