
Draft: WIP Training checkpointing

Open · Alexandru-Mihai GHERGHESCU requested to merge training_checkpointing into main
Files changed: 2 (+27 / −18)
@@ -160,6 +160,7 @@ class Trainer():
self.state = TrainerState()
if self.args.log_on_main_rank_only:
# TODO this doesn't seem to work?
logger.addFilter(logging_utils.FilterMainRankOnly())
def train(self, resume_from_checkpoint: bool|Path = False) -> None:
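For context on the TODO above, here is a minimal sketch of what a main-rank-only logging filter typically looks like; this is an assumption about logging_utils.FilterMainRankOnly, not the repo's actual code. One common reason such a filter "doesn't seem to work" is that it is added to a single logger object, so records emitted through other loggers (or through handlers directly) never pass through it; attaching the filter to the handler(s) is usually the more robust choice.

import logging
import torch.distributed as dist

class FilterMainRankOnly(logging.Filter):
    """Only let log records through on the main process (rank 0)."""
    def filter(self, record: logging.LogRecord) -> bool:
        # outside of a distributed run, let everything through
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

# attaching to a handler also filters records that arrive via propagation
# from other loggers:
#   handler.addFilter(FilterMainRankOnly())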
@@ -174,9 +175,17 @@ class Trainer():
`TrainingArguments`).
"""
# number of examples in dataloader (globally, not per GPU); an example
# is a sequence of length context_len
num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
# weight update steps per epoch (globally, not per GPU)
num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
# total number of weight update steps
max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
# global batch size
global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size()
fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
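To make the bookkeeping above concrete, a worked example with made-up numbers (none of these values come from this MR):

# len(train_dataloader) = 1000 batches, per_device_batch_size = 8,
# gradient_accumulation_steps = 4, world_size = 2, num_train_epochs = 3
num_examples = 1000 * 8                 # 8,000 sequences of length context_len
num_update_steps_per_epoch = 1000 // 4  # 250 weight updates per epoch
max_steps = 3 * 250                     # 750 weight updates in total
global_batch_size = 8 * 4 * 2           # 64 sequences per weight update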
@@ -186,15 +195,13 @@ class Trainer():
self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0)
# scaler used for mixed precision fp16 training on GPU
# TODO, with bf16, you don't need a scaler
self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
# initialize a Trainer state (will be overridden later if resuming from
# checkpoint)
self.state = TrainerState()
epochs_trained = 0
steps_trained_in_current_epoch = 0
# potentially resume from saved checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
last_checkpoint: Path | None = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir)
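A rough sketch of what checkpoint_utils.get_last_checkpoint_dir could do, purely to make the resume path above easier to follow; the repo's actual helper may differ:

import re
from pathlib import Path

def get_last_checkpoint_dir(checkpoints_dir: Path) -> Path | None:
    """Return the checkpoint-<step> directory with the highest step, or None."""
    best_step, best_dir = -1, None
    for p in Path(checkpoints_dir).iterdir():
        m = re.fullmatch(r'checkpoint-(\d+)', p.name)
        if p.is_dir() and m and int(m.group(1)) > best_step:
            best_step, best_dir = int(m.group(1)), p
    return best_dir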
@@ -227,7 +234,7 @@ class Trainer():
self.model.train()
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Total number of examples in train dataloader = {num_examples:,}")
logger.info(f" Num epochs = {self.args.num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
@@ -241,15 +248,15 @@ class Trainer():
if hasattr(self.train_dataloader, "set_epoch"):
self.train_dataloader.set_epoch(epoch)
# resume from checkpoint, update learning rate
# if resume from checkpoint, update learning rate
for step in range(0, self.state.global_steps):
self.args.lr_scheduler.step()
for step, inputs in enumerate(self.train_dataloader):
# check if resume from checkpoint and skip the first batches of
# data in the dataloader
if steps_trained_in_current_epoch < (self.state.global_steps % (self.state.current_epoch + 1)):
# if resume from checkpoint, skip the first batches of data in
# the train dataloader
if step < (self.state.global_steps % (self.state.current_epoch + 1)):
continue
inputs = inputs['input_ids']
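For comparison, a self-contained sketch of the batch-skipping arithmetic when resuming mid-epoch, under the assumption that global_steps counts weight updates; this is a suggestion for reference, not code from this MR:

def batches_to_skip(global_steps: int, num_update_steps_per_epoch: int,
                    gradient_accumulation_steps: int) -> int:
    """Dataloader batches already consumed in the current (partial) epoch."""
    steps_done_in_epoch = global_steps % num_update_steps_per_epoch
    return steps_done_in_epoch * gradient_accumulation_steps

# e.g. 260 updates done, 250 updates per epoch, grad accumulation of 4
# -> the first 10 * 4 = 40 batches of the current epoch were already seen
assert batches_to_skip(260, 250, 4) == 40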
@@ -272,6 +279,7 @@ class Trainer():
tr_loss = loss.item()
loss = loss / self.args.gradient_accumulation_steps # normalize to account for gradient accumulation
# TODO only for fp16, not fp32 or bf16
self.scaler.scale(loss).backward()
# update only after gradient_accumulation_steps
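Regarding the two scaler TODOs: the usual AMP pattern only enables GradScaler for float16; a disabled scaler turns scale()/unscale_()/step() into cheap no-ops, so the bf16 and fp32 paths can share the same code. A sketch under those assumptions (the flag name is made up):

import torch

use_fp16 = True                                   # hypothetical flag
amp_dtype = (torch.bfloat16
             if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
             else torch.float16)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16 and amp_dtype is torch.float16)

# inside the loop the same lines then work for fp16, bf16 and fp32:
#   with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_fp16):
#       loss = compute_loss(...)
#   scaler.scale(loss).backward()    # returns loss unscaled when disabled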
@@ -280,7 +288,7 @@ class Trainer():
# when grad_acc_steps > 1 and the number of batches isn't
# divisible by grad_acc_steps
steps_trained_in_current_epoch += 1
self.state.global_steps += 1
# gradient clipping
self.scaler.unscale_(self.args.optimizer)
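For reference, the full clipping/update sequence that the unscale_ call above belongs to, as a small runnable example (toy model, scaler disabled so it also runs without a GPU):

import torch

model = torch.nn.Linear(4, 2)                              # toy stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=False)          # enabled=True only for fp16 on GPU

loss = model(torch.randn(8, 4)).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                 # grads back to true scale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                                     # skips the update if grads overflowed
scaler.update()
optimizer.zero_grad(set_to_none=True)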
@@ -294,19 +302,20 @@ class Trainer():
self.args.lr_scheduler.step()
# lr = self.args.lr_scheduler.get_last_lr()[0]
if steps_trained_in_current_epoch % self.args.log_steps:
if step % self.args.log_steps == 0:
logger.info(f"Loss is {tr_loss:,}")
self.progress.update(1)
if steps_trained_in_current_epoch % self.args.save_steps:
# save checkpoint
print(f'saving checkpoint at step {steps_trained_in_current_epoch}')
os.mkdir(os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}'), parent=True)
self.state.to_json_file(os.path.join(self.args.checkpoints_dir, os.path.join(f'checkpoint-{self.state.global_steps}', 'trainer_state.json')))
torch.save(self.model.state_dict(),
os.path.join(self.args.checkpoints_dir,
os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt')))
if step % self.args.save_steps == 0:
if dist_utils.get_rank() == 0: # TODO this only works for DDP, not FSDP
# save checkpoint
print(f'saving checkpoint at step {self.state.global_steps}')
os.makedirs(os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}'))
self.state.to_json_file(os.path.join(self.args.checkpoints_dir, os.path.join(f'checkpoint-{self.state.global_steps}', 'trainer_state.json')))
torch.save(self.model.state_dict(),
os.path.join(self.args.checkpoints_dir,
os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt')))
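A sketch of how the saving block above could be factored into a helper; the exist_ok flag and the extra optimizer/scheduler files are suggestions (resuming exactly usually needs them), not something this MR does:

import os
import torch

def save_checkpoint(checkpoints_dir, global_steps, model, optimizer, lr_scheduler, state):
    ckpt_dir = os.path.join(checkpoints_dir, f'checkpoint-{global_steps}')
    os.makedirs(ckpt_dir, exist_ok=True)        # don't crash if the directory already exists
    state.to_json_file(os.path.join(ckpt_dir, 'trainer_state.json'))
    torch.save(model.state_dict(), os.path.join(ckpt_dir, 'model.pt'))
    # optimizer and scheduler state are needed to resume training exactly
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, 'optimizer.pt'))
    torch.save(lr_scheduler.state_dict(), os.path.join(ckpt_dir, 'lr_scheduler.pt'))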
def save_model(self, save_dir: Path) -> None:
"""