基本概念
class singleconv(nn.module): def __init__(self): super(singleconv, self).__init__() self.conv = nn.conv2d(3, 16, kernel_size=3, padding=1) def forward(self, x): return self.conv(x)
super(singleconv, self).__init__()
这行代码用于调用父类(基类)的初始化方法。具体来说,它让 singleconv
类调用其父类 nn.module
的 __init__()
方法,确保父类被正确初始化。
这是面向对象编程中继承机制的重要组成部分,特别是在构建 pytorch 神经网络模型时尤为关键。
语法分解
让我们逐部分解析这个表达式:
super()
- 这是 python 内置函数,用于返回一个代理对象,该对象将方法调用委托给父类或兄弟类。super(singleconv, self)
- 这部分创建了一个代理对象,指向singleconv
类的父类(在这个例子中是nn.module
)。第一个参数指定类本身,第二个参数通常是类的实例(即self
)。super(singleconv, self).__init__()
- 通过代理对象调用父类的__init__()
方法,确保父类的初始化代码被执行。
为什么这很重要?
一般继承原则
在继承关系中,子类需要确保父类被正确初始化,因为:
- 父类可能设置了子类依赖的重要属性和状态
- 父类可能执行了必要的初始化逻辑
- 如果不调用父类的初始化方法,继承链就会断开,子类将无法完全继承父类的功能
在 pytorch 中的特殊重要性
在 pytorch 的 nn.module
上下文中,调用 super().__init__()
尤为关键,因为:
参数管理 -
nn.module
的初始化方法设置了追踪和管理模型参数(如卷积层的权重和偏置)的机制模块注册 - 它建立了子模块的注册系统,使 pytorch 能够识别模型的层次结构
功能支持 - 它启用了许多核心功能,包括:
- 参数迁移(使用
.to(device)
将模型移动到 cpu/gpu) - 模型保存和加载(使用
torch.save()
和torch.load()
) - 训练和评估模式切换(
.train()
和.eval()
) - 自动求导支持
- 参数迁移(使用
如果省略会发生什么?
如果您省略 super().__init__()
调用,可能会导致:
class singleconvwithoutsuper(nn.module): def __init__(self): # 没有调用 super().__init__() self.conv = nn.conv2d(3, 16, kernel_size=3, padding=1) def forward(self, x): return self.conv(x) model = singleconvwithoutsuper() print(list(model.parameters())) # 可能返回空列表,因为参数没有被正确注册
这样的模型会出现多种问题:
- 参数不会被正确注册和跟踪
- 无法正常使用
.to(device)
迁移到 gpu - 保存和加载模型时可能丢失参数
- 梯度可能无法正确传播
python 3 的简化语法
在 python 3 中,可以使用更简洁的语法:
class singleconv(nn.module): def __init__(self): super().__init__() # 简化版,等效于 super(singleconv, self).__init__() self.conv = nn.conv2d(3, 16, kernel_size=3, padding=1) def forward(self, x): return self.conv(x)
这种写法功能完全相同,但更简洁易读。python 3 的 super()
不带参数时会自动使用当前类和实例。
多重继承中的作用
在涉及多重继承的复杂情况下,super()
特别有用。它会按照方法解析顺序(mro)正确调用父类,避免同一个父类被初始化多次:
class a: def __init__(self): print("a init") class b(a): def __init__(self): super().__init__() print("b init") class c(a): def __init__(self): super().__init__() print("c init") class d(b, c): def __init__(self): super().__init__() print("d init")
当创建 d
的实例时,super()
确保每个父类的 __init__
只被调用一次,遵循 python 的 mro 规则。
执行结果
当我们创建 d
类的实例(如 d = d()
)时,输出结果为:
a init
c init
b init
d init
这个输出顺序可能看起来有些反直觉,特别是 c init
出现在 b init
之前,尽管在 d
的继承声明中 b
是第一个父类。让我们深入分析这是为什么。
方法解析顺序 (mro)
python 使用一种称为方法解析顺序(method resolution order, mro)的机制来确定多重继承中方法查找的顺序。mro 决定了当调用 super()
时,python 应该按照什么顺序查找父类的方法。
我们可以通过以下方式查看一个类的 mro:
print(d.__mro__) # 或者 print(d.mro())
对于我们的例子,d
类的 mro 是:
(<class '__main__.d'>, <class '__main__.b'>, <class '__main__.c'>, <class '__main__.a'>, <class 'object'>)
这意味着 python 在 d
的实例上查找方法时,会按照 d -> b -> c -> a -> object
的顺序搜索。
super() 的工作原理
理解 super()
的关键在于:super()
不是简单地调用父类的方法,而是调用 mro 中当前类之后的下一个类的方法。
当在一个类中使用 super().__init__()
时,python 会查找 mro 中当前类之后的下一个类,并调用其 __init__
方法。这是一个非常强大的机制,尤其是在处理复杂的继承结构时。
详细的执行流程
让我们逐步追踪 d()
创建实例时的执行流程:
调用 d.__init__()
- 执行
super().__init__()
- 根据 mro,
d
之后的类是b
,所以调用b.__init__()
- 执行
进入 b.__init__()
- 执行
super().__init__()
- 根据 mro,
b
之后的类是c
,所以调用c.__init__()
- 执行
进入 c.__init__()
- 执行
super().__init__()
- 根据 mro,
c
之后的类是a
,所以调用a.__init__()
- 执行
进入 a.__init__()
- 打印
"a init"
a.__init__()
执行完毕,返回到c.__init__()
- 打印
回到 c.__init__()
- 打印
"c init"
c.__init__()
执行完毕,返回到b.__init__()
- 打印
回到 b.__init__()
- 打印
"b init"
b.__init__()
执行完毕,返回到d.__init__()
- 打印
回到 d.__init__()
- 打印
"d init"
d.__init__()
执行完毕
- 打印
图解说明
下面是继承结构和执行顺序的图解:
a
/ \
b c
\ /
d
执行顺序(箭头表示调用方向):
d.__init__() → b.__init__() → c.__init__() → a.__init__() ↓ d.__init__() ← b.__init__() ← c.__init__() ← 返回并打印 "a init" ↓ ↓ ↓ ↓ ↓ 打印 "c init" ↓ 打印 "b init" 打印 "d init"
与直接调用父类方法的对比
为了理解 super()
的价值,让我们看看如果不使用 super()
而是直接调用父类的 __init__
方法会发生什么:
class a: def __init__(self): print("a init") class b(a): def __init__(self): a.__init__(self) # 直接调用 a.__init__ print("b init") class c(a): def __init__(self): a.__init__(self) # 直接调用 a.__init__ print("c init") class d(b, c): def __init__(self): b.__init__(self) # 直接调用 b.__init__ c.__init__(self) # 直接调用 c.__init__ print("d init")
使用这种方式,创建 d
的实例将输出:
a init # 从 b.__init__ 调用 b init a init # 从 c.__init__ 调用,a 被初始化了两次! c init d init
可以看到,a.__init__()
被调用了两次!这可能导致资源重复分配、状态不一致或其他问题。
总结
super(singleconv, self).__init__()
这行代码是确保 pytorch 神经网络模块正确初始化的关键步骤。它调用父类 nn.module
的初始化方法,设置必要的内部状态,并启用 pytorch 的核心功能。
在 python 3 中,推荐使用更简洁的 super().__init__()
语法。无论使用哪种形式,确保在每个继承自 nn.module
的类的 __init__
方法中调用它,这是构建正确功能的 pytorch 模型的基础。
简单来说,这行代码就像告诉您的类:“在我开始自己的初始化工作之前,请确保我从父类继承的所有功能都已正确设置好。”
到此这篇关于python中的super().__init__()用法详解的文章就介绍到这了,更多相关python中super().__init__()内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!
发表评论