torch.squeeze和torch.unsqueeze用法
一、torch.squeeze用法
torch.squeeze可以对tensor的维度进行压缩。
>>> d = torch.arange(6).view(1,2,1,3) >>> d tensor([[[[0, 1, 2]], [[3, 4, 5]]]]) >>> d.shape torch.Size([1, 2, 1, 3])
不加参数直接调用torch.squeeze函数
>>> d.squeeze() tensor([[0, 1, 2], [3, 4, 5]]) >>> d.squeeze().shape torch.Size([2, 3])
可以看出,维度为(1,2,1,3)直接变为了(2,3)。即去掉了维度为1的所有维度。
调用torch.squeeze的同时传入参数
我们依次传入参数0,1,2,3。0指向第一个维度,1指向第二个维度,同理,以此类推。由于tensor的shape是(1,2,1,3)。所以知识第二个维度和第四个维度为1.所以传入的参数可以是1或者3.这样就可以消除对应位置的维度。
>>> d.shape torch.Size([1, 2, 1, 3]) >>> d.squeeze(0).shape torch.Size([2, 1, 3]) >>> d.squeeze(1).shape torch.Size([1, 2, 1, 3]) >>> d.squeeze(2).shape torch.Size([1, 2, 3]) >>> d.squeeze(3).shape torch.Size([1, 2, 1, 3])
二、torch.unsqueeze函数的用法
构造一个4 x 5 的tensor
>>> a = torch.arange(20).view(4,5) >>> a tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]) >>> a.shape torch.Size([4, 5])
a.unsqueeze(0) 在索引0对应位置增加一个维度
>>> a.unsqueeze(0) tensor([[[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]]) >>> a.unsqueeze(0).shape torch.Size([1, 4, 5])
a.unsqueeze(1) 在索引1对应位置增加一个维度
>>> a.unsqueeze(1) tensor([[[ 0, 1, 2, 3, 4]], [[ 5, 6, 7, 8, 9]], [[10, 11, 12, 13, 14]], [[15, 16, 17, 18, 19]]]) >>> a.unsqueeze(1).shape torch.Size([4, 1, 5])
a.unsqueeze(2) 在索引2对应位置增加一个维度
>>> a.unsqueeze(2) tensor([[[ 0], [ 1], [ 2], [ 3], [ 4]], [[ 5], [ 6], [ 7], [ 8], [ 9]], [[10], [11], [12], [13], [14]], [[15], [16], [17], [18], [19]]]) >>> a.unsqueeze(2).shape torch.Size([4, 5, 1])
a.unsqueeze(3) 会把报错。这是因为超出了范围
>>> a.unsqueeze(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
上一篇:
JS实现多线程数据分片下载
下一篇:
如何录制微课?教师必看