
Draft: WIP Training checkpointing

Open Alexandru-Mihai GHERGHESCU requested to merge training_checkpointing into main
Files (2)  +8 −7
@@ -300,13 +300,14 @@ class Trainer():
                 self.progress.update(1)
                 if steps_trained_in_current_epoch % self.args.save_steps:
-                    # save checkpoint
-                    print(f'saving checkpoint at step {steps_trained_in_current_epoch}')
-                    os.mkdir(os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}'), parent=True)
-                    self.state.to_json_file(os.path.join(self.args.checkpoints_dir, os.path.join(f'checkpoint-{self.state.global_steps}', 'trainer_state.json')))
-                    torch.save(self.model.state_dict(),
-                               os.path.join(self.args.checkpoints_dir,
-                                            os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt')))
+                    if dist_utils.get_rank() == 0: # TODO this only works for DDP
+                        # save checkpoint
+                        print(f'saving checkpoint at step {steps_trained_in_current_epoch}')
+                        os.makedirs(os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}'))
+                        self.state.to_json_file(os.path.join(self.args.checkpoints_dir, os.path.join(f'checkpoint-{self.state.global_steps}', 'trainer_state.json')))
+                        torch.save(self.model.state_dict(),
+                                   os.path.join(self.args.checkpoints_dir,
+                                                os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt')))
 
     def save_model(self, save_dir: Path) -> None:
         """
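For context: the added rank-0 guard means only one process writes the checkpoint under DDP, where all ranks hold identical weights after gradient synchronization. A minimal sketch of how a checkpoint written with this layout could be restored on resume is shown below; `load_checkpoint` is a hypothetical helper, not part of this MR — only the `checkpoint-{global_steps}/model.pt` and `trainer_state.json` layout comes from the diff above.

```python
import os
import torch

def load_checkpoint(model, checkpoints_dir, global_steps):
    # Hypothetical helper (not in this MR): restore a checkpoint saved with the
    # layout above, i.e. <checkpoints_dir>/checkpoint-<global_steps>/ containing
    # 'model.pt' (model weights) and 'trainer_state.json' (trainer state).
    ckpt_dir = os.path.join(checkpoints_dir, f'checkpoint-{global_steps}')
    model.load_state_dict(torch.load(os.path.join(ckpt_dir, 'model.pt'),
                                     map_location='cpu'))
    # Restoring the trainer state (global_steps, epoch counters, ...) is left to
    # the caller, e.g. by reading back the JSON written by to_json_file().
    return os.path.join(ckpt_dir, 'trainer_state.json')
```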