当前位置: 代码网 > it编程>前端脚本>Python > PyTorch中torch.nn模块的实现

PyTorch中torch.nn模块的实现

2024年09月23日 Python 我要评论
torch.nn是 pytorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:1.nn.module原理和要点:nn.module是所有神经网

torch.nn 是 pytorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:

1. nn.module

原理和要点nn.module 是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.module,并实现其 forward 方法。

使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。

主要 api 和使用场景

  • __init__: 初始化模型参数。
  • forward: 定义前向传播逻辑。
  • parameters: 返回模型的所有参数。
import torch
import torch.nn as nn

class mymodel(nn.module):
    def __init__(self):
        super(mymodel, self).__init__()
        self.linear = nn.linear(10, 1)
    
    def forward(self, x):
        return self.linear(x)

model = mymodel()
print(model)

2. layers(层)

  • 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
  • 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。

2.1 nn.linear(全连接层)

linear = nn.linear(10, 5)
input = torch.randn(1, 10)
output = linear(input)
print(output)

2.2 nn.conv2d(二维卷积层)

conv = nn.conv2d(in_channels=1, out_channels=3, kernel_size=3)
input = torch.randn(1, 1, 5, 5)
output = conv(input)
print(output)

2.3 nn.maxpool2d(二维最大池化层)

maxpool = nn.maxpool2d(kernel_size=2)
input = torch.randn(1, 1, 4, 4)
output = maxpool(input)
print(output)

3. loss functions(损失函数)

  • 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
  • 使用场景:用于计算训练过程中需要最小化的误差。

3.1 nn.mseloss(均方误差损失)

mse_loss = nn.mseloss()
input = torch.randn(3, 5)
target = torch.randn(3, 5)
loss = mse_loss(input, target)
print(loss)

3.2 nn.crossentropyloss(交叉熵损失)

cross_entropy_loss = nn.crossentropyloss()
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = cross_entropy_loss(input, target)
print(loss)

4. optimizers(优化器)

  • 原理和要点:优化器用于调整模型参数,以最小化损失函数。
  • 使用场景:用于训练模型,通过反向传播更新参数。

4.1 torch.optim.sgd(随机梯度下降)

import torch.optim as optim

model = mymodel()
optimizer = optim.sgd(model.parameters(), lr=0.01)
criterion = nn.mseloss()

# training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 1))
    loss.backward()
    optimizer.step()

4.2 torch.optim.adam(自适应矩估计)

optimizer = optim.adam(model.parameters(), lr=0.001)
# training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 1))
    loss.backward()
    optimizer.step()

5. activation functions(激活函数)

  • 原理和要点:激活函数引入非线性,使模型能够拟合复杂的函数。
  • 使用场景:用于激活输入,增加模型表达能力。

5.1 nn.relu(修正线性单元)

relu = nn.relu()
input = torch.randn(2)
output = relu(input)
print(output)

6. normalization layers(归一化层)

  • 原理和要点:归一化层用于标准化输入,改善训练的稳定性和速度。
  • 使用场景:用于标准化激活值,防止梯度爆炸或消失。

6.1 nn.batchnorm2d(二维批量归一化)

batch_norm = nn.batchnorm2d(3)
input = torch.randn(1, 3, 5, 5)
output = batch_norm(input)
print(output)

7. dropout layers(丢弃层)

  • 原理和要点:dropout 层通过在训练过程中随机丢弃一部分神经元来防止过拟合。
  • 使用场景:用于防止模型过拟合,增加模型的泛化能力。

7.1 nn.dropout

dropout = nn.dropout(p=0.5)
input = torch.randn(2, 3)
output = dropout(input)
print(output)

8. container modules(容器模块)

  • 原理和要点:容器模块用于组合多个层,构建复杂的神经网络结构。
  • 使用场景:用于组合多个层,形成更复杂的网络结构。

8.1 nn.sequential(顺序容器)

model = nn.sequential(
    nn.linear(10, 20),
    nn.relu(),
    nn.linear(20, 5)
)
input = torch.randn(1, 10)
output = model(input)
print(output)

8.2 nn.modulelist(模块列表)

layers = nn.modulelist([
    nn.linear(10, 20),
    nn.relu(),
    nn.linear(20, 5)
])

input = torch.randn(1, 10)
for layer in layers:
    input = layer(input)
print(input)

9. functional api (torch.nn.functional)

  • 原理和要点:包含大量用于深度学习的无状态函数,这些函数通常是操作层的底层实现。
  • 使用场景:用于在前向传播中灵活调用函数。

9.1 f.relu(relu 激活函数)

import torch.nn.functional as f

input = torch.randn(2)
output = f.relu(input)
print(output)

9.2 f.cross_entropy(交叉熵损失函数)

input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = f.cross_entropy(input, target)
print(loss)

9.3 f.conv2d(二维卷积)

input = torch.randn(1, 1, 5, 5)
weight = torch.randn(3, 1, 3, 3)  # manually defined weights
output = f.conv2d(input, weight)
print(output)

10. parameter (torch.nn.parameter)

  • 原理和要点torch.nn.parameter 是 torch.tensor 的一种特殊子类,用于表示模型的可学习参数。它们在 nn.module 中会自动注册为参数。
  • 使用场景:用于定义模型中的可学习参数。

示例代码:

class mymodelwithparam(nn.module):
    def __init__(self):
        super(mymodelwithparam, self).__init__()
        self.my_param = nn.parameter(torch.randn(10, 10))
    
    def forward(self, x):
        return x @ self.my_param

model = mymodelwithparam()
input = torch.randn(1, 10)
output = model(input)
print(output)

# 查看模型参数
for name, param in model.named_parameters():
    print(name, param.size())

综合示例

下面是一个结合上述各个部分的综合示例:

import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim

class mycomplexmodel(nn.module):
    def __init__(self):
        super(mycomplexmodel, self).__init__()
        self.conv1 = nn.conv2d(1, 32, kernel_size=3)
        self.bn1 = nn.batchnorm2d(32)
        self.conv2 = nn.conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.batchnorm2d(64)
        self.dropout = nn.dropout(0.25)
        self.fc1 = nn.linear(64*12*12, 128)
        self.fc2 = nn.linear(128, 10)
        self.custom_param = nn.parameter(torch.randn(128, 128))

    def forward(self, x):
        x = f.relu(self

.bn1(self.conv1(x)))
        x = f.max_pool2d(x, 2)
        x = f.relu(self.bn2(self.conv2(x)))
        x = f.max_pool2d(x, 2)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = f.relu(self.fc1(x))
        x = x @ self.custom_param
        x = self.fc2(x)
        return f.log_softmax(x, dim=1)

model = mycomplexmodel()
criterion = nn.crossentropyloss()
optimizer = optim.adam(model.parameters(), lr=0.001)

for epoch in range(10):
    optimizer.zero_grad()
    input = torch.randn(64, 1, 28, 28)
    target = torch.randint(0, 10, (64,))
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(f'epoch {epoch+1}, loss: {loss.item()}')

通过以上示例,可以更清晰地理解 torch.nn 模块的整体架构、原理、要点及其具体使用场景。

到此这篇关于pytorch中torch.nn模块的实现的文章就介绍到这了,更多相关pytorch torch.nn模块内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com