本文正在参与「金石方案 . 分割6万现金大奖」

作者简介:秃头小苏,致力于用最浅显的语言描述问题

往期回忆:对立生成网络GAN系列——GAN原理及手写数字生成小事例    对立生成网络GAN系列——DCGAN简介及人脸图像生成事例

近期目标:写好专栏的每一篇文章

支撑小苏:点赞、收藏⭐、留言

pytorch保存与加载模型详解篇

写在前面

  最近,看到不少小伙伴问pytorch怎么保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击☞☞☞了解概况

​  可是必定有很多人是不愿意看官网的,所以我仍是花一篇文章来为咱们介绍介绍。当然了,在介绍中我会参加自己的一些了解,让咱们有一个更深的知道。假如预备好了的话,就让咱们开端吧。⏳⏳⏳

模型保存与加载

  pytorch中介绍了几种不同的模型保存和加载办法,我会在下文逐个为咱们介绍。首要先让咱们来随意界说一个模型,如下:【用的是pytorch官网的例子】

# 模型界说
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

  界说好模型结构后,咱们能够实例化这个模型:

#模型初始化
model = TheModelClass()

  模型初始化过后,咱们就一起来看看模型保存和加载的办法吧。

办法1

  办法1是官方引荐的一种办法,咱们直接来看代码好了,如下:

# 保存模型
torch.save(model.state_dict(), './model/model_state_dict.pth')

​  该办法后边的参数'./model/model_state_dict.pth'为模型的保存途径,模型后缀名官方引荐运用.pth.pt,当然了,你取别的后缀名也是完全可行的。☘☘☘

  介绍了模型的保存,下面就来看看办法1是怎么加载模型的。【这儿我说明一点,模型保存往往是在练习中进行的,而模型加载大都用在模型推理中,它们存在两个文件中,故咱们在推理过程中要先实列化模型】

# 加载模型
model_test1 = TheModelClass()   # 加载模型时应先实例化模型
# load_state_dict()函数接收一个字典,所以不能直接将'./model/model_state_dict.pth'传入,而是先运用load函数将保存的模型参数反序列化
model_test1.load_state_dict(torch.load('./model/model_state_dict.pth'))
model_test1.eval()    # 模型推理时设置

​  在上述的代码注释中我有写到,咱们运用load_state_dict()加载模型时先需求运用load办法将保存的模型参数==反序列化==,load后的结果是一个字典,这时就能够经过load_state_dict()办法来加载了。


  这儿我来简略说一下我了解的反序列化,其和序列化是相对应的一个概念。序列化便是把内存中的数据保存到磁盘中,像咱们运用torch.save()办法保存模型便是序列化;而反序列化则是将硬盘中的数据加载到内存当中,显然咱们加载模型的过程便是反序列化过程。【大致的意思如下图所示,偶尔在水群的时分看到一个画图软件,是不是还挺好看的】

pytorch模型保存、加载与续训练


办法2

办法2十分简略,直接上代码:

# 保存模型
torch.save(model, './model/model.pt')    #这儿咱们保存模型的后缀名取.pt
# 加载模型
model_test2 = torch.load('./model/model.pt')     
model_test2.eval()   # 模型推理时设置

  可是这种办法是不引荐运用的,因为你运用这种办法保存模型,然后再加载时会遇到各式各样的过错。为了加深咱们了解,咱们来看这样的一个例子。文件的结构如下图所示:

pytorch模型保存、加载与续训练

models.py文件中存储的是模型的界说,其坐落文件夹models下。save_model.py文件中写的是保存模型的代码,如下:

from models.models import TheModelClass
from torch import optim
import torch
#模型初始化
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ## 保存加载办法2——save/load
# # 保存模型
# torch.save(models, './models/models.pt')

履行此文件后,会生成models.pt文件,咱们在履行load_mode.py文件即可完成加载,load_mode.py内容如下:

from models.models import TheModelClass
import torch
## 加载办法2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./models/models.pt')     
model_test2.eval()   # 模型推理时设置
print(model_test2)

此刻咱们能够正常加载。但假如咱们将models文件夹修改为model,如下:

pytorch模型保存、加载与续训练

​此刻咱们在运用如下代码加载模型的话就会呈现过错:

from models.models import TheModelClass
import torch
## 加载办法2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./model/models.pt')     #这儿需求修改一下文件途径  
model_test2.eval()   # 模型推理时设置
print(model_test2)

pytorch模型保存、加载与续训练

​  呈现这种过错的原因是运用办法2进行模型保存的时分会把模型结构界说文件途径记载下来,加载的时分就会依据途径解析它然后装载参数;当把模型界说文件途径修改以后,运用torch.load(path)就会报错。


  其实运用办法2进行模型的保存和加载还会存在各种问题,感兴趣的能够看看这篇博文。总归,在咱们往后的运用中,尽量不要用办法2来加载模型。

办法3

​  pytorch还为咱们提供了一种模型保存与加载的办法——checkpoint。这种办法保存的是一个字典,假如咱们程序在运转中因为某种原因异常中止,那么这种办法能够很方便的让咱们接着上次练习,正因为这样,我十分引荐咱们运用这种办法进行模型的保存与加载。下面就让咱们一起来看看办法3是怎么运用的吧!!!

  首要,咱们同样运用torch.save来保存模型,可是这儿保存的是一个字典,里面能够填入你需求保存的参数,如下:

# 保存checkpoint
torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            'loss':loss
            }, './model/model_checkpoint.tar'    #这儿的后缀名官方引荐运用.tar
            )

​ 接着咱们来看看怎么加载checkpoint,代码如下:

# 加载checkpoint
model_checkpoint = TheModelClass()
optimizer =  optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar')    # 先反序列化模型
model_checkpoint.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

  看了我上文的介绍,咱们是否知道怎么运用checkpoint了呢,我想咱们都会觉得这个不是很难,但要自己写可能仍是不好掌握,那么第一次就让我来带领咱们看看怎么在代码中运用checkpoint吧!!!

​  这节我选用cifar10数据集完成物体分类的例子,我的这篇博文对其进行了详细介绍,那么这儿介绍checkpoint我将利用这个demo来为咱们讲解。首要咱们直接来看模型保存的完好代码,如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、预备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、建立神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self, input):
        input = self.model1(input)
        return input
#4、创立网络模型
net = Net()
#5、设置丢失函数、优化器
#丢失函数
loss_fun = nn.CrossEntropyLoss()   #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate)   #SGD:梯度下降算法
#6、设置网络练习中的一些参数
total_train_step = 0   #记载总计练习次数
total_test_step = 0    #记载总计测验次数
Max_epoch = 10    #规划练习轮数
#7、开端进行练习
for epoch in range(Max_epoch):
    print("---第{}轮练习开端---".format(epoch))
    net.train()     #开端练习,不是有必要的,在网络中有BN,dropout时需求
    #因为练习集数据较多,这儿我没用练习集练习,而是选用测验集(test_dataset_loader)当练习集,但思维是共同的
    for data in test_dataset_loader:      
        imgs, targets = data
        targets = targets.to(device)
        outputs = net(imgs)
        #比较输出与真实值,计算Loss
        loss = loss_fun(outputs, targets)
        #反向传达,调整参数
        optimizer.zero_grad()    #每次让梯度重置
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 50 == 0:
            print("---第{}次练习完毕, Loss:{})".format(total_train_step, loss.item()))
    if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 这儿的后缀名官方引荐运用.tar
        )
    if epoch > 5:
        print("---意外中断---")
        break

  整个流程和这篇文章根本共同,不清楚的主张先花几分钟阅览一下哈。主要区别便是在最后保存模型的时分我运用了checkpoint进行保存,且两个epoch保存一次。当epoch=6时,我设置了一个break模拟程序意外中断,中断后能够来看一下终端的输出信息,如下图所示:

pytorch模型保存、加载与续训练

  咱们能够看到在进行第6轮循环时,程序中断了,此刻最新的保存的模型是第五次练习结果,如下:

pytorch模型保存、加载与续训练

  一起注意到第5次练习完毕的loss在2.0左右,假如咱们下次接着练习,丢失应该是在2.0附近。


​  好了,上面因为一些糟糕的原因导致程序中断了,现在我想接着上次练习的结果继续练习,我该怎么办呢?代码如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、预备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、建立神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self, input):
        input = self.model1(input)
        return input
#4、创立网络模型
net = Net()
#5、设置丢失函数、优化器
#丢失函数
loss_fun = nn.CrossEntropyLoss()   #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate)   #SGD:梯度下降算法
#6、设置网络练习中的一些参数
total_train_step = 0   #记载总计练习次数
total_test_step = 0    #记载总计测验次数
Max_epoch = 10    #规划练习轮数
##########################################################################################
# 加载checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar')    # 先反序列化模型
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################
#7、开端进行练习
for epoch in range(start_epoch+1, Max_epoch):
    print("---第{}轮练习开端---".format(epoch))
    net.train()     #开端练习,不是有必要的,在网络中有BN,dropout时需求
    for data in test_dataset_loader:
        imgs, targets = data
        targets = targets.to(device)
        outputs = net(imgs)
        #比较输出与真实值,计算Loss
        loss = loss_fun(outputs, targets)
        #反向传达,调整参数
        optimizer.zero_grad()    #每次让梯度重置
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 50 == 0:
            print("---第{}次练习完毕, Loss:{})".format(total_train_step, loss.item()))
    if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 这儿的后缀名官方引荐运用.tar
        )

  这儿的代码相较之前的多了一个加载checkpoint的过程,我将其截取出来,如下图所示:

pytorch模型保存、加载与续训练

  经过加载checkpoint咱们就保存了之前练习的参数,从而完成断点续练习,咱们直接来看履行此代码的结果,如下图所示:

pytorch模型保存、加载与续训练

​  从上图能够看出咱们的练习是从第6轮开端的,并且初始的loss为1.99,和2.0挨近。这就说明了咱们现已完成了中断后恢复练习的操作。

  这儿我简略的说两句,上文介绍checkpoint的用法时,练习中断和练习恢复我是放在两个文件中的进行的,可是在实践中咱们必定是在一个文件中运转,那这该怎么办呢?其实办法很简略啦,咱们只需求设置一个if条件将加载checkpoint的部分放在练习文件中,然后设置一个参数来控制if条件的履行即可。具体细节我就不给咱们介绍了,假如有不明白的谈论区见吧!!!

总结

  这部分仍是蛮简略的,但一些细节仍是需求咱们自行考量,我就为咱们介绍到这儿啦,希望咱们都能够有所收成吧。

如若文章对你有所帮助,那就

        

pytorch模型保存、加载与续训练