torch.jit.trace 消除TracerWarning
在使用torch.jit.trace时,经常会碰到如下warning:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values,so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
当然这些warning,可能并不会在c++调用时产生错误,权作洁癖吧。
此博客汇总了个人尝试过的一些warning的破解方式:
1.慎用tensor.shape/torch.size()
1.1 生成新的tensor
如下:
y=x.new(x.size())
产生如下错误:
TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect.
可改成:
y=torch.zeros_like(x)
1.2 if/while语句中
在if/while语句中,有时需要用到tensor.shape信息,若如下操作:
if x.shape[0]: x*=2
报错:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
可修改为:
if x.numel(): x*=2
这里发现一点,在使用torch.jit.trace时,tensor.shape/tensor.size()某个维度的信息常被当做是tensor,如:
print(x.size(0)) #tensor(8)
而正常情况下只是一个int变量,这可能就是torch.jit.trace经常报类似错误的关键所在。
尚不清楚是bug,还是自己没用对。
2.比较两个单元素的tensor
import torch import torch.nn class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.a=torch.tensor(1) self.b=torch.tensor(2) def forward(self, x): if self.a!=self.b: x*=2 return x
报错:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
这里要使用torch.equal()或者tensor_A.equal(tensor_B).
import torch import torch.nn class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.a=torch.tensor(1) self.b=torch.tensor(2) def forward(self, x): if self.a.equal(self.b): x*=2 return x
3.使用参数strict=False
torch.jit.trace(model,input_imgs,strict=False)
如果你的模型输出不是以tensor的形式,而是如list等的形式,可以设置strict=False来消除warning。
strict (bool, optional) – run the tracer in a strict mode or not (default: True). Only turn this off when you want the tracer to record our mutable container types (currently list/dict) and you are sure that the container you are using in your problem is a constant structure and does not get used as control flow (if, for) conditions.
4.一般性的方法
-
不使用numpy; 变量尽可能用tensor形式。
5.其他
未充分验证:
使用 tensor.index_select。
参考文献
[1]
上一篇:
JS实现多线程数据分片下载
下一篇:
软件测试工程师发展方向,主要有哪些?