Pytorch框架之tensor类型转换(type, type_as)

type – 指定类型改变

原型:type(dtype=None, non_blocking=False, **kwargs) 按输入的数据类型进行转换并返回 【注意: 如果不指定数据类型,就返回本身的数据】

data = torch.ones(2, 2)
print(data.dtype)
#result: torch.int64
# 可能在操作过程中指定其他数据类型--这里就按照ones--对应int64类型
data = data.type(torch.float32)  # 要接收类型已经改变的tensor数据,否则data本身是不会直接改变数据类型的
print(data.dtype)
#result: torch.float32

type_as --按照给定的tensor的类型转换类型

原型:type_as(tensor) 按给定的tensor确定转换的数据类型–如果类型相同则不做改变–否则改为传入的tensor类型–并返回类型改变的tensor数据。

data = torch.ones(2, 2)
data_float = torch.randn(2, 2)  # 这里的数据类型为torch.float64
print(data.dtype)
#result: torch.int64
# 可能在操作过程中指定其他数据类型--这里就按照ones--对应int64类型
data = data.type_as(data_float )
print(data.dtype)
#result: torch.float64
经验分享 程序员 微信小程序 职场和发展