避免数据丢失¶
在训练模型的过程中,偶尔可能会发生网络波动、显卡故障、内存溢出、异常终止等问题。 这些问题发生将可能会导致训练进程终止,前面训练将功亏一篑,又要重头开始训练,不仅浪费时间并且浪费精力,因此每间隔一段时间就将训练模型信息保存到磁盘一次显得尤为重要,而这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复训练。
对于 Pytorch
框架官方提供了checkpoint
功能,用来定期保存模型和恢复模型,具体可参考官方文档:https://pytorch.org/tutorials/beginner/saving_loading_models.html#
对于 tensorflow
框架官方也提供了类似于 Pytorch 框架的 checkpoint
功能,在 tensorflow
框架中被称为 ModelCheckpoint
,用来在训练期间保存模型,具体可参考官方文档:https://www.tensorflow.org/tutorials/keras/save_and_load?hl=zh-cn