In PyTorch, `nn.Module` is the most fundamental base class in the neural-network toolkit: every model is built on top of it. Understanding and becoming fluent with `nn.Module` is key to mastering PyTorch.
1. What is nn.Module

`nn.Module` is the base class of every neural-network module in PyTorch. You can think of it as a "container for a neural network" that takes care of:
- Network layers (such as `nn.Linear`, `nn.Conv2d`, and so on)
- The forward-pass logic (the `forward` method)
- Model parameters (registered automatically and trainable)
- Nesting (a module can contain any number of submodules)
- Convenience utilities, e.g. for saving and loading the model
2. Basic usage

2.1 Defining a custom model class
```python
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```
2.2 Instantiating and calling the model
```python
model = MyNet()
x = torch.randn(32, 784)  # batch_size = 32
output = model(x)         # calls forward automatically
```
3. `__init__()` and `forward()` in detail

3.1 `__init__()`

- Defines the structure of the model: submodules, layers, and so on.
- For example, `self.conv1 = nn.Conv2d(...)` is registered automatically, so its parameters are tracked by the module.

3.2 `forward()`

- Defines the forward-pass logic.
- Should not be called by hand; use the `model(x)` form instead.
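A minimal sketch of why `model(x)` is preferred: `nn.Module.__call__` dispatches registered hooks around `forward`, so calling `forward` directly skips them. The hook below exists purely for illustration.

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)   # registered automatically as a submodule

    def forward(self, x):
        return self.fc(x)

model = Tiny()

# A forward hook fires only when the module is called as model(x),
# because hooks are dispatched inside nn.Module.__call__.
model.register_forward_hook(
    lambda mod, inp, out: print("hook fired, output shape:", out.shape)
)

x = torch.randn(3, 4)
model(x)          # runs the hook, then returns the output
model.forward(x)  # bypasses __call__, so the hook does NOT fire
```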
4. Common layer modules

Module | Purpose | Example |
---|---|---|
`nn.Linear` | Fully connected layer | `nn.Linear(128, 64)` |
`nn.Conv2d` | Convolution layer | `nn.Conv2d(3, 16, 3)` |
`nn.ReLU` | Activation function | `nn.ReLU()` |
`nn.Sigmoid` | Activation function | `nn.Sigmoid()` |
`nn.BatchNorm2d` | Batch normalization | `nn.BatchNorm2d(16)` |
`nn.Dropout` | Dropout layer | `nn.Dropout(0.5)` |
`nn.LSTM` | LSTM layer | `nn.LSTM(10, 20)` |
`nn.Sequential` | Sequential container of layers | see below |
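To make the table concrete, here is a quick sketch (shapes chosen arbitrarily for illustration) of how a few of these layers transform tensors:

```python
import torch
import torch.nn as nn

x_img = torch.randn(8, 3, 32, 32)         # (batch, channels, height, width)
conv = nn.Conv2d(3, 16, 3, padding=1)     # 3 -> 16 channels, 3x3 kernel
bn = nn.BatchNorm2d(16)
print(bn(conv(x_img)).shape)              # torch.Size([8, 16, 32, 32])

x_vec = torch.randn(8, 128)
fc = nn.Linear(128, 64)
drop = nn.Dropout(0.5)
print(drop(torch.relu(fc(x_vec))).shape)  # torch.Size([8, 64])

x_seq = torch.randn(8, 5, 10)             # (batch, seq_len, features)
lstm = nn.LSTM(10, 20, batch_first=True)
out, (h, c) = lstm(x_seq)
print(out.shape)                          # torch.Size([8, 5, 20])
```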
五、模型嵌套结构(子模块)
你可以将一个 nn.module
作为另一个模块的子模块嵌套:
```python
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layer(x)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()
        self.output = nn.Linear(64, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return self.output(x)
```
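A short, illustrative usage sketch (using the `Net` class just defined) showing that nested submodules are tracked recursively and can be enumerated:

```python
import torch.nn as nn

model = Net()

# Direct children: block1, block2, output
for name, child in model.named_children():
    print(name, "->", child.__class__.__name__)

# named_modules() walks the whole tree, yielding names like "block1.layer.0"
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(name, module)
```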
6. Built-in methods and attributes

Method / attribute | Description |
---|---|
`model.parameters()` | Returns all trainable parameters (pass this to the optimizer) |
`model.named_parameters()` | Iterator over parameters together with their names |
`model.children()` | Iterator over the direct submodules |
`model.eval()` | Switch to evaluation mode (dropout is disabled, batch norm uses running statistics) |
`model.train()` | Switch to training mode |
`model.to(device)` | Move the model to GPU/CPU |
`model.state_dict()` | Parameter dictionary of the model (for saving) |
`model.load_state_dict()` | Load a parameter dictionary |
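A short sketch exercising several of these methods on the `MyNet` model from section 2 (the file name is arbitrary, chosen only for illustration):

```python
import torch

model = MyNet()

# Parameters for the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Inspect named parameters and their shapes
for name, p in model.named_parameters():
    print(name, tuple(p.shape))   # e.g. fc1.weight (128, 784)

# Switch modes and move devices
model.eval()                      # inference behavior for Dropout / BatchNorm
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Save and reload weights
torch.save(model.state_dict(), "mynet.pt")   # illustrative file name
model2 = MyNet()
model2.load_state_dict(torch.load("mynet.pt", map_location=device))
model2.train()                    # back to training mode
```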
7. Using nn.Sequential

`nn.Sequential` is an ordered container that simplifies defining a network:
```python
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
```

This is equivalent to the hand-written `nn.Module` subclass above, and it suits architectures whose forward pass is a straight "pipeline" of layers.
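`nn.Sequential` also accepts an `OrderedDict`, which lets you give each layer a name instead of an index; a small sketch:

```python
from collections import OrderedDict
import torch.nn as nn

model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(784, 128)),
    ("act", nn.ReLU()),
    ("fc2", nn.Linear(128, 10)),
]))

print(model.fc1)   # layers are now accessible by name instead of index
```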
8. Complete worked example: an MNIST classifier
```python
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)


# Instantiate the model
model = MNISTNet()
print(model)

# Configure training
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Example training loop
# (train_loader is assumed to be a DataLoader over MNIST, not shown here)
for epoch in range(10):
    for images, labels in train_loader:
        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
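To round out the training loop above, here is a hedged sketch of an evaluation pass; it assumes a `test_loader` analogous to `train_loader`:

```python
# Evaluation: switch to eval mode and disable gradient tracking
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:   # test_loader is assumed to exist
        logits = model(images)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"accuracy: {correct / total:.4f}")
model.train()  # switch back before further training
```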
9. Common pitfalls and tips

Problem | Explanation |
---|---|
`forward()` seems to have no effect | Call `model(x)`; do not invoke `model.forward(x)` by hand |
Forgetting `super().__init__()` | Submodules will not be registered |
Parameters not registered | Layers/modules must be assigned as attributes, i.e. `self.xxx = ...` (plain Python lists are not tracked; see the sketch below) |
Mixing up train/test modes | Remember to switch between `model.train()` and `model.eval()` |
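The "parameters not registered" pitfall is easiest to see by comparing a plain Python list with `nn.ModuleList`; a minimal sketch:

```python
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: the layers are NOT registered, so parameters() misses them
        self.layers = [nn.Linear(10, 10) for _ in range(3)]

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule
        self.layers = nn.ModuleList(nn.Linear(10, 10) for _ in range(3))

print(len(list(Broken().parameters())))  # 0
print(len(list(Fixed().parameters())))   # 6 (weight + bias for each of 3 layers)
```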
10. Summary

Item | Description |
---|---|
`__init__()` | Defines the model structure (submodules, layers) |
`forward()` | Defines the forward pass |
Automatic parameter registration | Every `self.xxx = nn.Xxx(...)` assignment is tracked |
Nested modules | Submodules are handled recursively |
Convenience methods | `.parameters()`, `.to()`, `.eval()`, and so on |
11. Comprehensive examples

Below are concise yet complete implementations of three classic deep-learning architectures (ResNet18, UNet, and a Transformer encoder), all built on PyTorch's `nn.Module`. They are a good starting point for beginners.
1. A concise ResNet18 (image classification)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample:
            identity = self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity          # residual connection
        return F.relu(out)


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )
        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).flatten(1)
        return self.fc(x)


def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
```
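A quick, illustrative sanity check of the output shape, continuing from the definitions above (the 224x224 input size is chosen to match ImageNet-style images):

```python
model = resnet18(num_classes=10)
x = torch.randn(2, 3, 224, 224)   # batch of 2 RGB images
print(model(x).shape)             # torch.Size([2, 10])
```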
2. UNet (image segmentation)
```python
class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder
        self.enc1 = UNetBlock(in_channels, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)
        self.enc4 = UNetBlock(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = UNetBlock(512, 1024)

        # Decoder with skip connections
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = UNetBlock(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = UNetBlock(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = UNetBlock(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = UNetBlock(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.upconv4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.upconv3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.upconv2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.upconv1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)
```
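A similar illustrative shape check; the spatial size should be divisible by 16 because the encoder downsamples four times:

```python
model = UNet(in_channels=1, out_channels=1)
x = torch.randn(1, 1, 128, 128)   # e.g. a single grayscale image
print(model(x).shape)             # torch.Size([1, 1, 128, 128])
```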
3. A simplified Transformer encoder (sequence modeling)
```python
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, heads,
                                          dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention followed by a residual connection and layer norm
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048,
                 num_layers=6, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, ff_dim) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(0.1)

    def _generate_positional_encoding(self, max_len, d_model):
        # Sinusoidal positional encoding
        pos = torch.arange(0, max_len).unsqueeze(1)
        i = torch.arange(0, d_model, 2)
        angle_rates = 1 / torch.pow(10000, (i / d_model))
        pos_enc = torch.zeros(max_len, d_model)
        pos_enc[:, 0::2] = torch.sin(pos * angle_rates)
        pos_enc[:, 1::2] = torch.cos(pos * angle_rates)
        return pos_enc.unsqueeze(0)

    def forward(self, x):
        b, t = x.shape
        x = self.embedding(x) + self.pos_encoding[:, :t].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
        return x
```
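And a usage sketch with dummy token ids (the vocabulary size, batch size, and sequence length below are arbitrary):

```python
encoder = TransformerEncoder(vocab_size=1000, embed_dim=512, n_heads=8,
                             ff_dim=2048, num_layers=2, max_len=128)
tokens = torch.randint(0, 1000, (4, 50))   # (batch, seq_len) of token ids
out = encoder(tokens)
print(out.shape)                           # torch.Size([4, 50, 512])
```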
4. Comparison

Architecture | Use case | Characteristics |
---|---|---|
ResNet18 | Image classification | Deep residual structure, well suited to transfer learning |
UNet | Image segmentation | Symmetric encoder-decoder with skip connections |
Transformer | NLP / sequence modeling | Pure attention, no convolutions or recurrence |