当前位置: 代码网 > it编程>前端脚本>Python > 如何使用PyTorch优化一个边缘检测器

如何使用PyTorch优化一个边缘检测器

2024年09月23日 Python 我要评论
import torchimport torch.nn as nnx = torch.tensor([[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0

import torch
import torch.nn as nn

x = torch.tensor([[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0]], dtype=float)
y = torch.tensor([[0,30,30,0],[0,30,30,0],[0,30,30,0],[0,30,30,0]], dtype=float)

conv2d = nn.conv2d(1,1,kernel_size=(3,3), bias=false, dtype=float)

x = x.reshape((1,1,6,6))
y = y.reshape((1,1,4,4))
lr = 0.0005

optim = torch.optim.rmsprop(conv2d.parameters(), lr=lr)
loss_fn = torch.nn.mseloss()
for i in range(4000):
    y_pred = conv2d(x)
    loss = loss_fn(y_pred, y)
    # 更新参数
    if 0: # 手动更新
        conv2d.zero_grad()
        loss.backward()
        conv2d.weight.data[:] -= lr * conv2d.weight.grad
    if 10: # 使用优化器更新
        optim.zero_grad()
        loss.backward()
        optim.step()
    if(i + 1) % 100 == 0:
        print(f'epoch {i+1}, loss {loss.sum():.4f}')

# 打印训练的参数
print(conv2d.weight.data.reshape(3,3))

输出:

epoch 100, loss 331.4604
epoch 200, loss 284.8803
epoch 300, loss 248.8032
epoch 400, loss 218.8007
epoch 500, loss 193.1186
epoch 600, loss 170.4061
epoch 700, loss 149.4530
epoch 800, loss 129.7580
epoch 900, loss 111.4134
epoch 1000, loss 94.5393
epoch 1100, loss 79.1782
epoch 1200, loss 65.3312
epoch 1300, loss 52.9822
epoch 1400, loss 42.1062
epoch 1500, loss 32.6718
epoch 1600, loss 24.6388
epoch 1700, loss 17.9555
epoch 1800, loss 12.5522
epoch 1900, loss 8.3332
epoch 2000, loss 5.1700
epoch 2100, loss 2.9096
epoch 2200, loss 1.4077
epoch 2300, loss 0.5341
epoch 2400, loss 0.1348
epoch 2500, loss 0.0166
epoch 2600, loss 0.0006
epoch 2700, loss 0.0000
epoch 2800, loss 0.0001
epoch 2900, loss 0.0001
epoch 3000, loss 0.0001
epoch 3100, loss 0.0001
epoch 3200, loss 0.0002
epoch 3300, loss 0.0002
epoch 3400, loss 0.0002
epoch 3500, loss 0.0002
epoch 3600, loss 0.0002
epoch 3700, loss 0.0002
epoch 3800, loss 0.0002
epoch 3900, loss 0.0002
epoch 4000, loss 0.0002
tensor([[ 1.3123, -0.0050, -1.0276],
        [ 0.8334,  0.0677, -0.8868],
        [ 0.8551, -0.0619, -1.0849]], dtype=torch.float64)

由训练出的结果可以看出卷积核参数与实际的卷积核挺接近了。

到此这篇关于如何使用pytorch优化一个边缘检测器的文章就介绍到这了,更多相关pytorch优化边缘检测器内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com