python
主页 > 脚本 > python >

pytorch GPU和CPU模型相互加载方式

2024-09-09 | 佚名 | 点击:

1 pytorch保存模型的两种方式

1.1 直接保存模型并读取

1

2

3

4

5

6

7

# 创建你的模型实例对象: model

model = net()

## 保存模型

torch.save(model, 'model_name.pth')

 

## 读取模型

model = torch.load('model_name.pth')

1.2 只保存模型中的参数并读取

1

2

3

4

5

6

7

## 保存模型

torch.save({'model': model.state_dict()}, 'model_name.pth')

 

## 读取模型

model = net()

state_dict = torch.load('model_name.pth')

model.load_state_dict(state_dict['model'])

如何保存模型决定了如何读取模型,一般来选择第二种来保存和读取。

2 GPU / CPU模型相互加载

2.1 单个CPU和单个GPU模型加载

pytorch 允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上。

加载模型参数的时候,在GPU和CPU训练的模型是不一样的,这两种模型是不能混为一谈的,下面分情况进行操作说明。

情况一:CPU -> CPU, GPU -> GPU

这种情况下我们都只用直接用下面的语句即可:

1

torch.load('model_dict.pth')

情况二:GPU -> CPG/GPU

GPU训练的模型,不知道放在CPU还是GPU运行,两种情况都要考虑

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

import torch

from torchvision import models

 

# 加载预训练的GPU模型权重文件

weights_path = 'model_gpu.pth'

 

# 定义一个与原模型结构相同的新模型

model = models.resnet50()

 

# 检查是否有可用的CUDA设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

# 将权重映射到相应的设备内存并加载到模型中

weights = torch.load(weights_path, map_location=device)

model.load_state_dict(weights)

 

# 设置为评估模式

model.eval()

 

print("Model is successfully loaded and can be used on a", device.type, "!")

情况三:CPU -> CPG/GPU

模型是在CPU上训练的,但不确定要在CPU还是GPU上运行时,两种情况都要考虑

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch

from torchvision import models

 

# 加载预训练的CPU模型权重文件

weights_path = 'model_cpu.pth'

 

# 定义一个与原模型结构相同的新模型

model = models.resnet50()

 

# 检查是否有可用的CUDA设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

# 将权重映射到相应的设备内存并加载到模型中

if device.type == 'cuda':

    model.to(device)

    weights = torch.load(weights_path, map_location=device)

else:

    weights = torch.load(weights_path, map_location='cpu')

 

model.load_state_dict(weights)

 

# 设置为评估模式

model.eval()

 

print("Model is successfully loaded and can be used on a", device.type, "!")

原文链接:
相关文章
最新更新