当前位置: 代码网 > it编程>前端脚本>Python > pytorch中nn.module如何自动调用forward()方法

pytorch中nn.module如何自动调用forward()方法

2024年08月01日 Python 我要评论
nn.Module自动调用forward方法的内部实现原理以及源码分析

当我们在使用pytorch的时候,观察以下例子:

from torch import nn
import torch

class net(nn.module):
    def __init__(self):
        print('init')
        super().__init__()
    
    def forward(self, input):
        output = input + 1
        print("forward")
        return output

net = net()
x = torch.tensor(1.0)
out = net(x)
print(out)

输出:
init
forward
tensord(2.)

定义一个神经网络net后实现forward方法,后续实例化为net网络后直接传入x即可自动调用forward方法,可明明我们并未调用forward()方法,具体如何实现的呢?

原因在于net继承于nn.module,而nn.module内部实现了_call_()方法

在python中,_call_() 是一个特殊方法(也称为魔术方法或魔法方法),用于使对象实例可以像函数一样被调用。当一个类定义了 _call_() 方法时,它的实例可以像调用函数一样被调用,而不仅仅是通过类中的其他方法来调用。如下例子:

class test():
    def __init__(self):
        print('init')

    def __call__(self,a):
      print('call')
      print(a)


test = test()
print(test('test'))

输出:
init
call
test

而在pytorch中,_call_()方法会自动调用forward()方法。

具体说:当你调用net(x)后,net会自动调用_call_()方法, 相当于net._call_(x),而_call_()又会自动调用forward()方法。

相当于:net(x) --> net._call_(x) -->net.forward(x)

了解到这里基本够了,后面是pytorch的源码分析,可以不看。

pytorch源码分析:

from typing import callable

class module():
  __call__ : callable[..., any] = _wrapped_call_impl

可以看到虽然有_call_()方法,但他并没有像我们上面写的那样定义一个def _call_()方法,分析一下这句话:

__call__ : callable[..., any] = _wrapped_call_impl

这里callable[…, any] 是一个类型提示,来源于type库,用来表示一个可调用对象的类型。类型提示是什么

类型提示是一种在函数参数、返回值以及变量上添加类型信息的注解,这些注解并不会影响运行时的行为,但可以被静态类型检查工具和ide用来提供更好的代码分析和错误检测。

你定义函数时 def fun(a:int) 这里的a:int 就是类型提示

让我们来分解这个类型提示的含义:

  1. […]: 这个省略号 … 表示可接受任意数量的参数,即函数或方法可以接受任意数量的参数,包括零个参数。

  2. any: 这个关键字表示函数或方法可以返回任意类型的值。

    综合起来,**_call_ : callable[…, any] 表示这个类型可以是一个接受任意数量参数的可调用对象,并且可以返回任意类型的值。**在类型提示中,这种方式用于表达灵活的函数签名,特别是当函数可能具有不固定参数数量或不确定返回类型时。

而后面的 = _wrapped_call_impl 相当于给这个_call_()方法取了一个别名, 接下来我们去找这个方法,源码如下:

  def _wrapped_call_impl(self, *args, **kwargs):
      if self._compiled_call_impl is not none:
          return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
      else:
          return self._call_impl(*args, **kwargs)

可以看到这里根据对象实例中的 _compiled_call_impl 属性是否不为 none执行不同的方法,

当_compiled_call_impl 不为none时,执行self._compiled_call_impl(*args, **kwargs)

否则执行self._call_impl(*args, **kwargs)

观看源码,可以发现 一般情况下,_compiled_call_impl 都是none的,除非你调用了module里的compile方法,否则都是执行self._call_impl(*args,**kwargs)的 ,先看看compile方法是干嘛的:

  def compile(self, *args, **kwargs):
      """
      compile this module's forward using :func:`torch.compile`.

      this module's `__call__` method is compiled and all arguments are passed as-is
      to :func:`torch.compile`.

      see :func:`torch.compile` for details on the arguments for this function.
      """
      self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)

根据源码的注释,可以看到这段代码的使用torch.compile编译模块的_call_()方法,其效果是将当前模块的 _call_ 方法进行编译或者优化,并将优化后的实现保存在 _compiled_call_impl 属性中。编译后的实现可以提升执行效率或者改进其他方面的性能,具体取决于 torch.compile 函数的实现和参数设置。

如果你未调用过该compile方法,那么 _compiled_call_impl就是none的,那么_wrapped_call_impl()接下来就会执行_call_impl(), _call_impl()源码如下:

def _call_impl(self, *args, **kwargs):
    forward_call = (self._slow_forward if torch._c._get_tracing_state() else self.forward)
    # if we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
        return forward_call(*args, **kwargs)

    try:
        result = none
    ...后面的太长省略

可以看到,分析forward_call = (self._slow_forward if torch._c._get_tracing_state() else self.forward)

torch._c._get_tracing_state() 函数用来检查当前是否处于追踪状态(tracing state)。如果处于追踪状态,说明在进行模型的图形化表示(例如在 torchscript 中),此时使用 _slow_forward 方法。

否则使用 self.forward 方法,至此终于知道__call__是怎么调用self.forward方法的了。

至于后面的代码就是输出一些日志啥的。

if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
        or _global_backward_pre_hooks or _global_backward_hooks
        or _global_forward_hooks or _global_forward_pre_hooks):
    return forward_call(*args, **kwargs)

意思是:如果当前模块没有任何的钩子(hooks),则跳过此方法中的其他逻辑,直接调用 forward_call 方法并返回其结果。
钩子(hooks)通常用于在模型运行过程中添加额外的操作或者记录信息,例如日志、梯度信息等。

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com