当前位置: 代码网 > it编程>编程语言>Asp.net > pytorch模型保存方式

pytorch模型保存方式

2024年09月09日 Asp.net 我要评论
pytorch模型保存保存模型主要分为两类:保存整个模型只保存模型参数1.保存加载整个模型(不推荐)保存整个网络模型,网络结构+权重参数torch.save(model,'net.pth')加载整个网

pytorch模型保存

保存模型主要分为两类:

  • 保存整个模型
  • 只保存模型参数

1.保存加载整个模型(不推荐)

保存整个网络模型,网络结构+权重参数

torch.save(model,'net.pth')

加载整个网络模型(可能比较耗时)

model=torch.load('net.pth')

2.只保存加载模型参数(推荐)

保存模型的权重参数(速度快,占内存少)

torch.save(model.state_dict(),'net_params.pth')

load 模型参数

因为我们只保存了 模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。

model=mynet()

#将模型参数加载到新模型中,torch.load返回的是一个ordereddict,说明.state_dict()只是把所有模型的参数都已ordereddict的形式存下来。

state_dict=torch.load('net_params.pth')
model.load_state_dict(state_dict)

note:保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。

load_state_dict的参数strict=false

new_model.load_state_dict(state_dict,strict=false)

如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key.

但是加上strict=false可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持代码网。

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com