每天讲解一点PyTorch 【17】Spatial Affinity代码实现分析
计算流程：先把特征图 `fm_in` 从 (N, C, H, W) 展平为 (N, C, H*W)，再沿第 1 维（通道维）对每个空间位置求 L2 范数并做归一化（分母加 1e-7 防止除零），最后用 `transpose` + `bmm` 求任意两个空间位置之间的余弦相似度，得到形状为 (N, 1, H*W, H*W) 的空间亲和矩阵（对角线元素约为 1）。

```
>>> import torch
>>> import torch.nn as nn
>>> fm_in = torch.randn(1,3,2,3)
>>> fm_in
tensor([[[[-0.1291, -0.0966, 0.0632],
          [-0.1640, -0.2832, 1.0553]],

         [[ 1.2854, 0.3400, 1.6823],
          [ 0.1555, -1.2052, 0.1288]],

         [[ 0.5609, 0.3766, 0.7720],
          [-2.0410, 0.2177, 1.4301]]]])
>>>
>>> fm_in.shape
torch.Size([1, 3, 2, 3])
>>>
>>> fm_in = fm_in.view(fm_in.size(0), fm_in.size(1), -1)
>>> fm_in
tensor([[[-0.1291, -0.0966, 0.0632, -0.1640, -0.2832, 1.0553],
         [ 1.2854, 0.3400, 1.6823, 0.1555, -1.2052, 0.1288],
         [ 0.5609, 0.3766, 0.7720, -2.0410, 0.2177, 1.4301]]])
>>> fm_in.shape
torch.Size([1, 3, 6])
>>>
>>> pow_out = torch.pow(fm_in,2)
>>> pow_out
tensor([[[1.6674e-02, 9.3340e-03, 3.9945e-03, 2.6887e-02, 8.0227e-02, 1.1137e+00],
         [1.6523e+00, 1.1559e-01, 2.8300e+00, 2.4166e-02, 1.4526e+00, 1.6588e-02],
         [3.1456e-01, 1.4179e-01, 5.9594e-01, 4.1658e+00, 4.7385e-02, 2.0452e+00]]])
>>> pow_out.shape
torch.Size([1, 3, 6])
>>>
>>> sum_out = torch.sum(pow_out, 1)
>>> sum_out
tensor([[1.9836, 0.2667, 3.4300, 4.2168, 1.5802, 3.1755]])
>>> sum_out.shape
torch.Size([1, 6])
>>>
>>> sqrt_out = torch.sqrt(sum_out)
>>> sqrt_out
tensor([[1.4084, 0.5164, 1.8520, 2.0535, 1.2571, 1.7820]])
>>> sqrt_out.shape
torch.Size([1, 6])
>>>
>>> unsqueeze_out = sqrt_out.unsqueeze(1)
>>> unsqueeze_out
tensor([[[1.4084, 0.5164, 1.8520, 2.0535, 1.2571, 1.7820]]])
>>> unsqueeze_out.shape
torch.Size([1, 1, 6])
>>>
>>> expand_out = unsqueeze_out.expand(fm_in.shape)
>>> expand_out
tensor([[[1.4084, 0.5164, 1.8520, 2.0535, 1.2571, 1.7820],
         [1.4084, 0.5164, 1.8520, 2.0535, 1.2571, 1.7820],
         [1.4084, 0.5164, 1.8520, 2.0535, 1.2571, 1.7820]]])
>>>
>>> fm_in.shape
torch.Size([1, 3, 6])
>>> expand_out.shape
torch.Size([1, 3, 6])
>>>
>>> norm_fm_in = fm_in / (expand_out + 0.0000001)
>>> norm_fm_in
tensor([[[-0.0917, -0.1871, 0.0341, -0.0799, -0.2253, 0.5922],
         [ 0.9127, 0.6583, 0.9083, 0.0757, -0.9588, 0.0723],
         [ 0.3982, 0.7291, 0.4168, -0.9939, 0.1732, 0.8025]]])
>>> fm_in.shape
torch.Size([1, 3, 6])
>>> norm_fm_in.shape
torch.Size([1, 3, 6])
>>>
>>> transpose_out = norm_fm_in.transpose(1,2)
>>> transpose_out
tensor([[[-0.0917, 0.9127, 0.3982],
         [-0.1871, 0.6583, 0.7291],
         [ 0.0341, 0.9083, 0.4168],
         [-0.0799, 0.0757, -0.9939],
         [-0.2253, -0.9588, 0.1732],
         [ 0.5922, 0.0723, 0.8025]]])
>>> transpose_out.shape
torch.Size([1, 6, 3])
>>>
>>> fm_out = transpose_out.bmm(norm_fm_in)
>>>
>>> fm_out
tensor([[[ 1.0000, 0.9083, 0.9919, -0.3194, -0.7854, 0.3313],
         [ 0.9083, 1.0000, 0.8955, -0.6599, -0.4628, 0.5219],
         [ 0.9919, 0.8955, 1.0000, -0.3483, -0.8064, 0.4204],
         [-0.3194, -0.6599, -0.3483, 1.0000, -0.2267, -0.8395],
         [-0.7854, -0.4628, -0.8064, -0.2267, 1.0000, -0.0638],
         [ 0.3313, 0.5219, 0.4204, -0.8395, -0.0638, 1.0000]]])
>>> fm_out.shape
torch.Size([1, 6, 6])
>>>
>>> fm_out = fm_out.unsqueeze(1)
>>> fm_out.shape
torch.Size([1, 1, 6, 6])
>>>
>>> fm_out
tensor([[[[ 1.0000, 0.9083, 0.9919, -0.3194, -0.7854, 0.3313],
          [ 0.9083, 1.0000, 0.8955, -0.6599, -0.4628, 0.5219],
          [ 0.9919, 0.8955, 1.0000, -0.3483, -0.8064, 0.4204],
          [-0.3194, -0.6599, -0.3483, 1.0000, -0.2267, -0.8395],
          [-0.7854, -0.4628, -0.8064, -0.2267, 1.0000, -0.0638],
          [ 0.3313, 0.5219, 0.4204, -0.8395, -0.0638, 1.0000]]]])
>>>
```
上一篇:
JS实现多线程数据分片下载