diff --git a/optimus/trainer.py b/optimus/trainer.py
index 00497ab257179c1ff179c0e49a20103f59296d97..3c22d7ebda404807de8c1c7550e01c68edddc15a 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -8,12 +8,22 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from optimus.utils import dist_utils, logging_utils
+from optimus.utils import dist_utils, logging_utils, checkpoint_utils
 
 
 logger = logging.getLogger(__name__)
 
 
+# files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
 class TrainingArguments():
     """
     Training arguments class to hold important switches and knobs related to
@@ -81,6 +91,37 @@ class TrainingArguments():
         self.save_limit = save_limit
 
 
+class TrainerState():
+    """
+    State of a trainer. This contains information that needs to be persisted
+    across checkpoints, such as current epoch or current number of steps.
+    """
+    def __init__(self):
+        self.current_epoch = 0
+        self.global_steps = 0
+
+    def to_json_file(self, file: Path):
+        """
+        Save the current trainer state as a json file.
+
+        Args:
+            file (Path): Path where the trainer state json file should be saved.
+        """
+        pass # TODO
+
+    def from_json_file(self, file: Path):
+        """
+        Reload the state of a trainer from a json file. Note: It is very likely
+        that training *will not work as expected* if you resume training from a
+        checkpoint with a trainer state that has been manually altered.
+
+        Args:
+            file (Path): Path to a json file (usually contained inside a
+                checkpoint directory).
+        """
+        pass # TODO
+
+
 class Trainer():
     """
     Generic PyTorch trainer implementation.
@@ -112,6 +153,10 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        # initialize an empty trainer state (if resuming from a checkpoint,
+        # will be updated later)
+        self.state = TrainerState()
+
         if self.args.log_on_main_rank_only:
             logger.addFilter(logging_utils.FilterMainRankOnly())
 
@@ -119,6 +164,13 @@ class Trainer():
         """
         Training loop of the trainer.
 
+        Args:
+            resume_from_checkpoint (bool or Path): Whether training should
+                resume from a checkpoint or start from scratch. If yes, then
+                training resumes from the last saved checkpoint (the one with
+                the highest global step) in `checkpoints_dir` (see
+                `TrainingArguments`).
+
         """
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
@@ -134,6 +186,26 @@ class Trainer():
         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
 
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+
+        # potentially resume from saved checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            last_checkpoint: Path = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir)
+            if last_checkpoint is not None:
+                logger.info(f"Resuming training from checkpoint {str(last_checkpoint)}")
+
+                checkpoint_utils.load_from_checkpoint(
+                    last_checkpoint,
+                    self.state,
+                    self.model,
+                    self.optimizer,
+                    self.scheduler,
+                    #self.scaler, # ?
+                )
+
+        self.model.train()
+
         logger.info("***** Running training *****")
         logger.info(f" Num examples = {num_examples:,}")
         logger.info(f" Num epochs = {self.args.num_train_epochs:,}")
@@ -142,16 +214,24 @@ class Trainer():
         logger.info(f" Global batch size (w. distributed & accumulation) = {global_batch_size:,}")
         logger.info(f" Total optimization steps = {max_steps:,}")
 
-        self.model.train()
-
-        # start training for num_train_epochs
-        for epoch in range(self.args.num_train_epochs):
+        # start training for num_train_epochs (or less if resumed from checkpoint)
+        for epoch in range(self.state.current_epoch, self.args.num_train_epochs):
 
             # needed for distributed sampler RNG state
             if hasattr(self.train_dataloader, "set_epoch"):
                 self.train_dataloader.set_epoch(epoch)
 
+            # resume from checkpoint, update steps
+            # TODO grad_acc_steps for this?
+            for step in range(0, self.state.global_steps):
+                self.args.lr_scheduler.step()
+
             for step, inputs in enumerate(self.train_dataloader):
+                steps_trained_in_current_epoch += 1
+
+                # check if resume from training
+                if steps_trained_in_current_epoch < self.state.global_steps:
+                    continue
 
                 inputs = inputs['input_ids']
                 inputs = inputs.to(self.args.device)
@@ -164,7 +244,7 @@ class Trainer():
                 logits = logits[..., :-1, :].contiguous().view(
                     -1,
                     self.model.module.vocab_size
-                    if dist_utils.is_dist_available_and_initialized()
+                    if dist_utils.is_dist_available_and_initialized() # TODO this should probably be a hasattr(self.model, 'module')
                     else self.model.vocab_size
                 )
diff --git a/optimus/utils/checkpoint_utils.py b/optimus/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4f98a59041f99ef8a54bb34213ab2d181480dd
--- /dev/null
+++ b/optimus/utils/checkpoint_utils.py
@@ -0,0 +1,36 @@
+import os
+import re
+import logging
+from pathlib import Path
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_last_checkpoint_dir(dir: Path):
+    """
+    Get last saved (highest global step) checkpoint in a directory, matching the
+    pattern `checkpoint-<number>`, where number is the global step of training.
+    Each checkpoint should itself be a directory.
+
+    Args:
+        dir (Path): The checkpoints directory.
+
+    Returns (Path or None): Returns the path to the last checkpoint, or `None`
+        if no valid checkpoints were found.
+
+    """
+    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
+    checkpoints_dirs = os.listdir(dir)
+    checkpoints = [
+        path
+        for path in checkpoints_dirs
+        if checkpoint_regex.search(path) is not None and os.path.isdir(os.path.join(dir, path))
+    ]
+    if len(checkpoints) == 0:
+        return None
+    return os.path.join(
+        dir,
+        max(checkpoints,
+            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
+    )
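Note on the resume path: the trainer hunk above calls `checkpoint_utils.load_from_checkpoint(...)`, which this patch does not define, and `TrainerState.to_json_file` / `from_json_file` are still `pass # TODO` stubs. Below is a minimal sketch of what those pieces might look like, assuming a plain JSON round-trip of the state attributes and checkpoint file names mirroring the constants added at the top of `optimus/trainer.py`. The helper bodies and the `pytorch_model.bin` weights file name are illustrative assumptions, not part of the patch.

# Sketch only -- not part of the patch. Possible shape of the helpers the
# trainer hunk relies on, under the assumptions stated above.
import json
from pathlib import Path

import torch

# assumed to mirror the constants added at the top of optimus/trainer.py
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"


def state_to_json_file(state, file: Path):
    # possible body for TrainerState.to_json_file: dump the plain attributes
    with open(file, "w") as f:
        json.dump(vars(state), f, indent=2)


def state_from_json_file(state, file: Path):
    # possible body for TrainerState.from_json_file: restore the attributes
    with open(file) as f:
        for key, value in json.load(f).items():
            setattr(state, key, value)


def load_from_checkpoint(checkpoint_dir: Path, state, model, optimizer, scheduler):
    # restore the epoch/step counters first, then the state dicts, matching the
    # call site in Trainer.train() above
    checkpoint_dir = Path(checkpoint_dir)
    state_from_json_file(state, checkpoint_dir / TRAINER_STATE_NAME)

    # "pytorch_model.bin" is an assumed weights file name; the patch does not
    # name one for non-FSDP checkpoints
    model.load_state_dict(torch.load(checkpoint_dir / "pytorch_model.bin", map_location="cpu"))
    optimizer.load_state_dict(torch.load(checkpoint_dir / OPTIMIZER_NAME, map_location="cpu"))
    scheduler.load_state_dict(torch.load(checkpoint_dir / SCHEDULER_NAME, map_location="cpu"))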