当我们在使用pytorch的时候,观察以下例子:
from torch import nn
import torch
class net(nn.module):
    def __init__(self):
        print('init')
        super().__init__()
    
    def forward(self, input):
        output = input + 1
        print("forward")
        return output
net = net()
x = torch.tensor(1.0)
out = net(x)
print(out)
输出:
init
forward
tensord(2.)
定义一个神经网络net后实现forward方法,后续实例化为net网络后直接传入x即可自动调用forward方法,可明明我们并未调用forward()方法,具体如何实现的呢?
原因在于net继承于nn.module,而nn.module内部实现了_call_()方法,
在python中,_call_() 是一个特殊方法(也称为魔术方法或魔法方法),用于使对象实例可以像函数一样被调用。当一个类定义了 _call_() 方法时,它的实例可以像调用函数一样被调用,而不仅仅是通过类中的其他方法来调用。如下例子:
class test():
    def __init__(self):
        print('init')
    def __call__(self,a):
      print('call')
      print(a)
test = test()
print(test('test'))
输出:
init
call
test
而在pytorch中,_call_()方法会自动调用forward()方法。
具体说:当你调用net(x)后,net会自动调用_call_()方法, 相当于net._call_(x),而_call_()又会自动调用forward()方法。
相当于:net(x) --> net._call_(x) -->net.forward(x)
了解到这里基本够了,后面是pytorch的源码分析,可以不看。
pytorch源码分析:
from typing import callable
class module():
  __call__ : callable[..., any] = _wrapped_call_impl
可以看到虽然有_call_()方法,但他并没有像我们上面写的那样定义一个def _call_()方法,分析一下这句话:
__call__ : callable[..., any] = _wrapped_call_impl
这里callable[…, any] 是一个类型提示,来源于type库,用来表示一个可调用对象的类型。类型提示是什么:
类型提示是一种在函数参数、返回值以及变量上添加类型信息的注解,这些注解并不会影响运行时的行为,但可以被静态类型检查工具和ide用来提供更好的代码分析和错误检测。
你定义函数时 def fun(a:int) 这里的a:int 就是类型提示
让我们来分解这个类型提示的含义:
-  […]: 这个省略号 … 表示可接受任意数量的参数,即函数或方法可以接受任意数量的参数,包括零个参数。 
-  any: 这个关键字表示函数或方法可以返回任意类型的值。 综合起来,**_call_ : callable[…, any] 表示这个类型可以是一个接受任意数量参数的可调用对象,并且可以返回任意类型的值。**在类型提示中,这种方式用于表达灵活的函数签名,特别是当函数可能具有不固定参数数量或不确定返回类型时。 
而后面的 = _wrapped_call_impl 相当于给这个_call_()方法取了一个别名, 接下来我们去找这个方法,源码如下:
  def _wrapped_call_impl(self, *args, **kwargs):
      if self._compiled_call_impl is not none:
          return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
      else:
          return self._call_impl(*args, **kwargs)
可以看到这里根据对象实例中的 _compiled_call_impl 属性是否不为 none执行不同的方法,
当_compiled_call_impl 不为none时,执行self._compiled_call_impl(*args, **kwargs)
否则执行self._call_impl(*args, **kwargs)
观看源码,可以发现 一般情况下,_compiled_call_impl 都是none的,除非你调用了module里的compile方法,否则都是执行self._call_impl(*args,**kwargs)的 ,先看看compile方法是干嘛的:
  def compile(self, *args, **kwargs):
      """
      compile this module's forward using :func:`torch.compile`.
      this module's `__call__` method is compiled and all arguments are passed as-is
      to :func:`torch.compile`.
      see :func:`torch.compile` for details on the arguments for this function.
      """
      self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)
根据源码的注释,可以看到这段代码的使用torch.compile编译模块的_call_()方法,其效果是将当前模块的 _call_ 方法进行编译或者优化,并将优化后的实现保存在 _compiled_call_impl 属性中。编译后的实现可以提升执行效率或者改进其他方面的性能,具体取决于 torch.compile 函数的实现和参数设置。
如果你未调用过该compile方法,那么 _compiled_call_impl就是none的,那么_wrapped_call_impl()接下来就会执行_call_impl(), _call_impl()源码如下:
def _call_impl(self, *args, **kwargs):
    forward_call = (self._slow_forward if torch._c._get_tracing_state() else self.forward)
    # if we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
        return forward_call(*args, **kwargs)
    try:
        result = none
    ...后面的太长省略
可以看到,分析forward_call = (self._slow_forward if torch._c._get_tracing_state() else self.forward)
torch._c._get_tracing_state() 函数用来检查当前是否处于追踪状态(tracing state)。如果处于追踪状态,说明在进行模型的图形化表示(例如在 torchscript 中),此时使用 _slow_forward 方法。
否则使用 self.forward 方法,至此终于知道__call__是怎么调用self.forward方法的了。
至于后面的代码就是输出一些日志啥的。
if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
        or _global_backward_pre_hooks or _global_backward_hooks
        or _global_forward_hooks or _global_forward_pre_hooks):
    return forward_call(*args, **kwargs)
意思是:如果当前模块没有任何的钩子(hooks),则跳过此方法中的其他逻辑,直接调用 forward_call 方法并返回其结果。
 钩子(hooks)通常用于在模型运行过程中添加额外的操作或者记录信息,例如日志、梯度信息等。
 
             我要评论
我要评论 
                                             
                                             
                                            
发表评论