PyTorch模型的保存与加载是怎么样的

PyTorch模型的保存与加载是怎么样的,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

成都创新互联公司是专业的鄂温克网站建设公司,鄂温克接单;提供网站设计、成都网站建设,网页设计,网站设计,建网站,PHP网站建设等专业做网站服务;采用PHP框架,可快速的进行鄂温克网站开发网页制作和功能扩展;专业做搜索引擎喜爱的网站,专业的做网站团队,希望更多企业前来合作!

torch.save()和torch.load():

torch.save()和torch.load()配合使用,
分别用来保存一个对象(任何对象,
不一定要是PyTorch中的对象)到文件,和从文件中加载一个对象.
加载的时候可以指明是否需要数据在CPU和GPU中相互移动.

Module.state_dict()和Module.load_state_dict():

Module.state_dict()返回一个字典,
该字典以键值对的方式保存了Module的整个状态.

Module.load_state_dict()可以从一个字典中加载参数到这个module和其后代,
如果strict是True,
那么所加载的字典和该module本身state_dict()方法返回的关键字必须严格确切的匹配上.
If strict is True, 
then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
返回值是一个命名元组:
NamedTuple with missing_keys and unexpected_keys fields,
分别保存缺失的关键字和未预料到的关键字.
如果自己的模型跟预训练模型只有部分层是相同的,
那么可以只加载这部分相同的参数,
只要设置strict参数为False来忽略那些没有匹配到的keys即可。
# 方式1:# model_path = 'model_name.pth'# model_params_path = 'params_name.pth'# ----保存----# torch.save(model, model_path)# ----加载----# model = torch.load(model_path)# 方式2:#----保存----# torch.save(model.state_dict(), model_params_path) #保存的文件名后缀一般是.pt或.pth
#----加载----# model=Model().cuda() #定义模型结构
# model.load_state_dict(torch.load(model_params_path))  #加载模型参数

说明:

# 保存/加载整个模型
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
这种保存/加载模型的过程使用了最直观的语法,
所用代码量少。这使用Python的pickle保存所有模块。
这种方法的缺点是,保存模型的时候,
序列化的数据被绑定到了特定的类和确切的目录。
这是因为pickle不保存模型类本身,而是保存这个类的路径,
并且在加载的时候会使用。因此,
当在其他项目里使用或者重构的时候,加载模型的时候会出错。




# 保存/加载 state_dict(推荐)
torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

自己选择要保存的参数,设置checkpoint:

#----保存----torch.save({
   
   
   'epoch': epoch + 1,'arch': args.arch,'state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
  	'loss': loss,'best_prec1': best_prec1,}, 'checkpoint_name.tar' )#----加载----checkpoint = torch.load('checkpoint_name.tar')#按关键字获取保存的参数
start_epoch = checkpoint['epoch']best_prec1 = checkpoint['best_prec1']state_dict=checkpoint['state_dict']model=Model()#定义模型结构
model.load_state_dict(state_dict)

保存多个模型到同一个文件:

#----保存----torch.save({
   
   
   
  'modelA_state_dict': modelA.state_dict(),
  'modelB_state_dict': modelB.state_dict(),
  'optimizerA_state_dict': optimizerA.state_dict(),
  'optimizerB_state_dict': optimizerB.state_dict(),
  ...
  }, PATH)#----加载----modelA = TheModelAClass(*args, **kwargs)modelB = TheModelAClass(*args, **kwargs)optimizerA = TheOptimizerAClass(*args, **kwargs)optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)modelA.load_state_dict(checkpoint['modelA_state_dict']modelB.load_state_dict(checkpoint['modelB_state_dict']optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']modelA.eval()modelB.eval()# or
modelA.train()modelB.train()# 在这里,保存完模型后加载的时候有时会
# 遇到CUDA out of memory的问题,
# 我google到的解决方法是加上map_location=‘cpu’

checkpoint = torch.load(PATH,map_location='cpu')

加载预训练模型的部分:

resnet152 = models.resnet152(pretrained=True) #加载模型结构和参数
pretrained_dict = resnet152.state_dict()"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
   也可以直接从官方model_zoo下载:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}# 更新现有的model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

或者写详细一点:

model_dict = model.state_dict()state_dict = {
   
   
   }for k, v in pretrained_dict.items():if k in model_dict.keys():# state_dict.setdefault(k, v)state_dict[k] = velse:print("Missing key(s) in state_dict :{}".format(k))model_dict.update(state_dict)model.load_state_dict(model_dict)

关于PyTorch模型的保存与加载是怎么样的问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注创新互联行业资讯频道了解更多相关知识。

当前标题:PyTorch模型的保存与加载是怎么样的
浏览地址:https://www.cdcxhl.com/article22/gooecc.html

成都网站建设公司_创新互联,为您提供网站建设建站公司响应式网站Google外贸网站建设手机网站建设

广告

声明:本网站发布的内容(图片、视频和文字)以用户投稿、用户转载内容为主,如果涉及侵权请尽快告知,我们将会在第一时间删除。文章观点不代表本网站立场,如需处理请联系客服。电话:028-86922220;邮箱:631063699@qq.com。内容未经允许不得转载,或转载时需注明来源: 创新互联

网站托管运营