Pytorch学习笔记:加载预训练模型
前言
torchvision是官方经常需要的包,包括: torchvision.datasets:预定义的训练集(比如MNIST、CIFAR10等); torchvision.models:包含预定义好的经典网络结构(比如AlexNet、VGG、ResNet等), torchvision.transforms:数据增强的方法
模型地址: 官方文档:
摘要:
本文记录的是读取已有网络结构和添加预训练模型
1.加载模型
代码如下:
import torchvision.models as models resnet50 = models.resnet50(pretrained=True)如果不需要采用torchvision预训练模型参数来初始化,将pretrained设置为False:
resnet50 = models.resnet50(pretrained=False)
2.修改加载模型
以resnet为例,默认的是ImageNet的1000类,比如我们要做二分类,分类猫和狗:
import torch import torch.nn as nn resnet.fc = nn.Linear(2048, 2)
此处复习了原模型中最后的全连接层。
3.加载预训练模型
在实际使用时,通常都会对预训练网络进行修改,那么预训练的参数就不能完全的使用,对两者进行比对,选择相同的参数加载进来
#加载model,model是自己定义好的模型 resnet50 = models.resnet50(pretrained=True) model =Net(...) #读取参数 预训练参数和当前网络参数 pretrained_dict =resnet50.state_dict() model_dict = model.state_dict() #将pretrained_dict里不属于model_dict的键剔除掉 pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新现有的model_dict model_dict.update(pretrained_dict) # 加载我们真正需要的state_dict model.load_state_dict(model_dict)
4.保存模型
方法1:只保存参数,模型自己调用,以resnet为例:
#只保存参数 torch.save(resnet50.state_dict(),ckp/model.pth) #先导入模型结构,再调用保存的参数 resnet=resnet50(pretrained=True) resnet.load_state_dict(torch.load(ckp/model.pth))
方法2:模型和参数全部保存:
#保存 torch.save (model, PATH) #恢复 model = torch.load(PATH)
回顾总结
1.如何调用官方模型 2.如何改写网络 3.如何加载预训练网络:
- 读取网络参数
- 剔除预训练网络中的不属于当前网络的参数
- 更新预训练模型的网络参数
- 加载更新的网络参数
4.保存和恢复模型的两种方法
参考:
下一篇:
晶振原理详解及测试方法