PyTorch函数中的__call__和forward函数
初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
init & call
代码:
class A(): def __init__(self): print(init函数) def __call__(self, param): print(call 函数, param) a = A()
输出 分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()
上面的代码加一行
class A(): def __init__(self): print(init函数) def __call__(self, param): print(call 函数, param) a = A() a(1)
输出 分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()
_ call_()中可以调用其它函数,如forward函数
class A(): def __init__(self): print(init函数) def __call__(self, param): print(call 函数, param) res = self.forward(param) return res + 2 def forward(self, input_): print(forward 函数, input_) return input_ a = A() b = a(1) print(结果b =,b)
分析:_call _()成功调用了forward(),且返回值给了b
nn.Module
看了上面的例子,就知道了_call _()的作用,那下面看更接近CNN的例子
from torch import nn import torch class Ding(nn.Module): def __init__(self): print(init) super().__init__() def forward(self, input): output = input + 1 print("forward") return output dzy = Ding() x = torch.tensor(1.0) out = dzy(x) print(out)
结果: 分析: 这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看 发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看
这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward() 所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn import torch class Ding(nn.Module): def __init__(self): print(init) super().__init__() def __call__(self, input_): print(重写call, 不用forward) return hhh def forward(self, input): output = input + 1 print("forward") return output dzy = Ding() x = torch.tensor(1.0) out = dzy(x) print(out)
总结
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。