pytorch值保存加载模型
当提到保存和加载模型时,有三个核心功能需要熟悉:
torch.save:将序列化的对象保存到disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。
[torch.load](https://pytorch.org/docs/stable/torch.html?highlight=torch load#torch.load):使用pickle unpickle工具将pickle的对象文件反序列化为内存。
torch.nn.Module.load_state_dict:使用反序列化状态字典加载model’s参数字典。
model 的STATE_DICT
在PyTorch中,torch.nn.Module的可学习参数(即权重和偏差),模块模型包含在model’s参数中(通过model.parameters()访问)。state_dict是个简单的Python dictionary对象,它将每个层映射到它的参数张量。
注意,只有具有可学习参数的层(卷积层、线性层等)才有model’s state_dict中的条目。优化器对象(connector .optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。
Example:
1 | import torch |
output
1 | Model's state_dict: |
保存和导入模型
- save
1 | torch.save(model.state_dict(), PATH) |
在保存模型进行推理时,只需要保存训练过的模型的学习参数即可。一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。
- load
1 | model = TheModelClass(*args, **kwargs) |
当然也可以这样读取和保存模型
- save
1 | torch.save(model, PATH) |
- load
1 | # Model class must be defined somewhere |