当前位置: 代码网 > it编程>前端脚本>Python > 【Pytorch 学习笔记】def init_weights() 初始化参数

【Pytorch 学习笔记】def init_weights() 初始化参数

2024年07月28日 Python 我要评论
nn.init.constant_() 是 torch.nn 中的用于填充数值的函数,这里用于指定初始化值,还有许多其他函数可用于此。在生成网络 net 时,会指定 net 最初的权重,对于一些预训练好的模型权重,就可以放在这个部分进行加载。在 CNN 中,经常可以看见 init_weights() 函数,它是用来初始化网络参数的。并不是所有的层都能初始化权重的,比如 nn.MaxPool2d(),它是无法初始化的。init_weights(self, m) 中的 m 就是指 net 中的某一层。

在 cnn 中,经常可以看见 init_weights() 函数,它是用来初始化网络参数的。

以下面代码为例:

class lenet(nn.module):

    def __init__(self):
        super(lenet, self).__init__()
        self.conv1 = nn.conv2d(3, 6, 5)
        self.conv2 = nn.conv2d(6, 16, 5)
        self.conv3 = nn.conv2d(16, 120, 5)
        self.pool = nn.maxpool2d(2, 2)
        self.fc1 = nn.linear(120, 84)
        self.fc2 = nn.linear(84, 10)

        self.apply(self.init_weights) # 调用初始化函数
        
    def init_weights(self, m):
        if isinstance(m, nn.conv2d):
            nn.init.constant_(m.weight, 0.1)
        elif isinstance(m, nn.linear):
            nn.init.constant_(m.weight, 0.2)
            nn.init.constant_(m.bias, 1)

    def forward(self, x):
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        x = f.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = f.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = lenet()

在生成网络 net 时,会指定 net 最初的权重,对于一些预训练好的模型权重,就可以放在这个部分进行加载。

我们打印 net 中的各层,如下:

print(net.modules)

---------------------------------------------------------------------

<bound method module.modules of lenet(
  (conv1): conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv3): conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (pool): maxpool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=false)
  (fc1): linear(in_features=120, out_features=84, bias=true)
  (fc2): linear(in_features=84, out_features=10, bias=true)
)>

init_weights(self, m) 中的 m 就是指 net 中的某一层。

isinstance() 函数来判断一个对象是否是一个已知的类型,isinstance(m, nn.conv2d) 就是判断 m 是不是 nn.conv2d。

nn.init.constant_() 是 torch.nn 中的用于填充数值的函数,这里用于指定初始化值,还有许多其他函数可用于此。

并不是所有的层都能初始化权重的,比如 nn.maxpool2d(),它是无法初始化的。

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com