Skip to content
Snippets Groups Projects

Draft: WIP Training checkpointing

Open Alexandru-Mihai GHERGHESCU requested to merge training_checkpointing into main
+ 12
10
@@ -92,10 +92,10 @@ class TrainerState():
self.current_epoch = 0
# this is weight updates, in total, over the whole training loop, with
# however many epochs; with gradient accumulation steps > 1, there
# will be more forward passes, but this is not captured here; global
# steps doesn't care about epochs, but is rather equal to epochs *
# steps_per_epoch
# however many epochs; with gradient accumulation steps > 1, there will
# be more forward + backward passes than weight updates, but this is not
# captured here; global_steps doesn't care about epochs, but is rather
# equal to epochs * steps_per_epoch
self.global_steps = 0
def to_json_file(self, file: Path):
@@ -249,14 +249,14 @@ class Trainer():
self.train_dataloader.set_epoch(epoch)
# if resume from checkpoint, update learning rate
for step in range(0, self.state.global_steps):
for _ in range(0, self.state.global_steps):
self.args.lr_scheduler.step()
for step, inputs in enumerate(self.train_dataloader):
# if resume from checkpoint, skip the first batches of data in
# the train dataloader
if step < (self.state.global_steps % (self.state.current_epoch + 1)):
if step < (self.state.global_steps % (self.state.current_epoch + 1)) * self.args.gradient_accumulation_steps:
continue
inputs = inputs['input_ids']
@@ -302,12 +302,12 @@ class Trainer():
self.args.lr_scheduler.step()
# lr = self.args.lr_scheduler.get_last_lr()[0]
if step % self.args.log_steps == 0:
if self.state.global_steps % self.args.log_steps == 0:
logger.info(f"Loss is {tr_loss:,}")
self.progress.update(1)
if step % self.args.save_steps == 0:
print(f'step: {self.state.global_steps}, save: {self.args.save_steps}')
if self.state.global_steps % self.args.save_steps == 0:
print(f'rank: {dist_utils.get_rank()}')
if dist_utils.get_rank() == 0: # TODO this only works for DDP, not FSDP
# save checkpoint
print(f'saving checkpoint at step {self.state.global_steps}')
@@ -317,6 +317,8 @@ class Trainer():
os.path.join(self.args.checkpoints_dir,
os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt')))
self.progress.update(1)
def save_model(self, save_dir: Path) -> None:
"""
Save model and tokenizer to a directory.
Loading