一、state_dict方式(推荐)
torch.save(model.state_dict(), PATH)
model = YourModel()
model.load_state_dict(torch.load(PATH))
model.eval()
记住一定要使用model.eval()
来固定dropout
和归一化
层,否则每次推理会生成不同的结果。
二、整个模型(结构+state_dict)方式
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
这种保存/加载模型的过程使用了最直观的语法,所用代码量少。这使用Python的pickle
保存所有模块。这种方法的缺点是,保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。这是因为pickle不保存模型类本身
,而是保存这个类的路径,并且在加载的时候会使用。因此,当在其他项目里使用或者重构的时候,加载模型的时候会出错。
记住一定要使用model.eval()
来固定dropout
和归一化
层,否则每次推理会生成不同的结果。
三、cptk方式
当我们在训练的时候,因为一些原因导致训练终止了,这个时候如果我们不想再浪费时间从头开始训练,就可以使用cptk的方式。这种方式不仅可以保存模型的state_dict,还可以保存训练中断时的训练的epoch,loss,优化器的state_dict等信息。
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
model = yourModel()
optimizer = yourOptimizer()
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - 或者 -
model.train()
示例