1 构造函数
2 _build_optimizer
根据配置中指定的优化器类型创建并返回一个适合用于模型训练的优化器对象
3 _build_scheduler
构建一个学习率调度器(scheduler)
4 train
5 run
6 _valid_epoch
7 load_model & save_model
保存/加载模型的状态字典 (self.model.state_dict()
) 和优化器的状态字典 (self.optimizer.state_dict()
)
def load_model(self, cache_name):
model_state, optimizer_state = torch.load(cache_name)
self.model.load_state_dict(model_state)
self.optimizer.load_state_dict(optimizer_state)
def save_model(self, cache_name):
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
# save optimizer when load epoch to train
torch.save((self.model.state_dict(), self.optimizer.state_dict()), cache_name)