当前位置: 代码网 > 科技>电脑产品>CPU > pytorch GPU和CPU模型相互加载方式

pytorch GPU和CPU模型相互加载方式

2024年09月09日 CPU 我要评论
1 pytorch保存模型的两种方式1.1 直接保存模型并读取# 创建你的模型实例对象: modelmodel = net()## 保存模型torch.save(model, 'model_name.

1 pytorch保存模型的两种方式

1.1 直接保存模型并读取

# 创建你的模型实例对象: model
model = net()
## 保存模型
torch.save(model, 'model_name.pth')

## 读取模型
model = torch.load('model_name.pth')

1.2 只保存模型中的参数并读取

## 保存模型
torch.save({'model': model.state_dict()}, 'model_name.pth')

## 读取模型
model = net()
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])
  • 第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。
  • 第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

如何保存模型决定了如何读取模型,一般来选择第二种来保存和读取。

2 gpu / cpu模型相互加载

2.1 单个cpu和单个gpu模型加载

pytorch 允许把在gpu上训练的模型加载到cpu上,也允许把在cpu上训练的模型加载到gpu上。

加载模型参数的时候,在gpu和cpu训练的模型是不一样的,这两种模型是不能混为一谈的,下面分情况进行操作说明。

情况一:cpu -> cpu, gpu -> gpu

  • gpu训练的模型,在gpu上使用;
  • cpu训练的模型,在cpu上使用,

这种情况下我们都只用直接用下面的语句即可:

torch.load('model_dict.pth')

情况二:gpu -> cpg/gpu

gpu训练的模型,不知道放在cpu还是gpu运行,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的gpu模型权重文件
weights_path = 'model_gpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的cuda设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
weights = torch.load(weights_path, map_location=device)
model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("model is successfully loaded and can be used on a", device.type, "!")

情况三:cpu -> cpg/gpu

模型是在cpu上训练的,但不确定要在cpu还是gpu上运行时,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的cpu模型权重文件
weights_path = 'model_cpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的cuda设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
if device.type == 'cuda':
    model.to(device)
    weights = torch.load(weights_path, map_location=device)
else:
    weights = torch.load(weights_path, map_location='cpu')

model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("model is successfully loaded and can be used on a", device.type, "!")

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持代码网。

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com