diff --git a/optimus/trainer.py b/optimus/trainer.py
index 00497ab257179c1ff179c0e49a20103f59296d97..b25955e8048e6f69e6b4c605c41aacfb6915c5d1 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,3 +1,5 @@
+import os
+import json
 import logging
 from pathlib import Path
 from typing import Optional
@@ -8,7 +10,7 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from optimus.utils import dist_utils, logging_utils
+from optimus.utils import dist_utils, logging_utils, checkpoint_utils
 
 
 logger = logging.getLogger(__name__)
@@ -81,6 +83,47 @@ class TrainingArguments():
         self.save_limit = save_limit
 
 
+class TrainerState():
+    """
+    State of a trainer. This contains information that needs to be persisted
+    across checkpoints, such as the current epoch or the current number of
+    steps.
+    """
+    def __init__(self):
+        self.current_epoch = 0
+
+        # total number of weight updates over the whole training run (i.e.
+        # across all epochs, roughly equal to epochs * steps_per_epoch); with
+        # gradient accumulation steps > 1 there are more forward passes than
+        # weight updates, which is not captured here
+        self.global_steps = 0
+
+    def to_json_file(self, file: Path):
+        """
+        Save the current trainer state as a json file.
+
+        Args:
+            file (Path): Path where the trainer state json file should be saved.
+        """
+        logger.info(f"Saving trainer state to {file}")
+        with open(file, 'w', encoding='utf-8') as f:
+            json.dump(vars(self), f, ensure_ascii=False, indent=4)
+        # TODO probably should just be to_json
+
+    @classmethod
+    def from_json_file(cls, file: Path | str):
+        """
+        Reload the state of a trainer from a json file. Note: It is very likely
+        that training *will not work as expected* if you resume training from a
+        checkpoint with a trainer state that has been manually altered.
+
+        Args:
+            file (Path): Path to a json file (usually contained inside a
+                checkpoint directory).
+        """
+        logger.info(f"Loading trainer state from {file}")
+        state = cls()
+        with open(file, 'r', encoding='utf-8') as f:
+            for key, value in json.load(f).items():
+                setattr(state, key, value)
+        return state
+
+
 class Trainer():
     """
     Generic PyTorch trainer implementation.
@@ -112,17 +155,37 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        # initialize an empty trainer state (if resuming from a checkpoint,
+        # it will be updated later)
+        self.state = TrainerState()
+
         if self.args.log_on_main_rank_only:
+            # TODO this doesn't seem to work?
             logger.addFilter(logging_utils.FilterMainRankOnly())
 
-    def train(self) -> None:
+    def train(self, resume_from_checkpoint: bool|Path = False) -> None:
         """
         Training loop of the trainer.
 
+        Args:
+            resume_from_checkpoint (bool or Path): Whether training should
+                resume from a checkpoint or start from scratch. If `True`,
+                training resumes from the last saved checkpoint (the one
+                with the highest global step) in `checkpoints_dir` (see
+                `TrainingArguments`).
+ """ + # number of examples in dataloader (globally, not per GPU); an example + # is a sequence of length context_len num_examples = len(self.train_dataloader) * self.args.per_device_batch_size + + # weight update steps per epoch (globally, not per GPU) num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps + + # total number of weight update steps max_steps = self.args.num_train_epochs * num_update_steps_per_epoch + + # global batch size global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size() fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 @@ -132,27 +195,70 @@ class Trainer(): self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0) # scaler used for mixed precision fp16 training on GPU + # TODO, with bf16, you don't need a scaler self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16) + # initialize a Trainer state (will be overriden later if resuming from + # checkpoint) + self.state = TrainerState() + + # potentially resume from saved checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + last_checkpoint: Path | None = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir) + if last_checkpoint is not None: + logger.info(f"Resuming training from checkpoint {str(last_checkpoint)}") + + # TODO check what model.state_dict, opt.state_dict etc. really is, + # and see how to replicate for trainer state + # TODO possibly move trainer state and trainer args to a separate + # trainer utils, since state should most likely be held separate + # from behaviour (Trainer) anyway, and also because TrainerState and + # checkpoint_utils have a circular dependency + self.state = TrainerState.from_json_file(os.path.join('checkpoint-1500', 'trainer_state.json')) + + # load model from .pt + self.model.load_state_dict(torch.load(os.path.join('checkpoint-1500', 'model.pt'))) + + #trainer_state.to_json_file(os.path.join('checkpoint-1500', 'trainer_state.json')) + # TODO write this function and check it + # TODO also write cu.save_checkpoint() + # checkpoint_utils.load_from_checkpoint( + # last_checkpoint, + # self.state, + # self.model, + # self.optimizer, + # self.scheduler, + # #self.scaler, # ? + # ) + + self.model.train() + logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Total number of examples in train dataloader = {num_examples:,}") logger.info(f" Num epochs = {self.args.num_train_epochs:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_batch_size:,}") logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") logger.info(f" Global batch size (w. 
distributed & accumulation) = {global_batch_size:,}") logger.info(f" Total optimization steps = {max_steps:,}") - self.model.train() - - # start training for num_train_epochs - for epoch in range(self.args.num_train_epochs): + # start training for num_train_epochs (or less if resumed from checkpoint) + for epoch in range(self.state.current_epoch, self.args.num_train_epochs): # needed for distributed sampler RNG state if hasattr(self.train_dataloader, "set_epoch"): self.train_dataloader.set_epoch(epoch) + # if resume from checkpoint, update learning rate + for step in range(0, self.state.global_steps): + self.args.lr_scheduler.step() + for step, inputs in enumerate(self.train_dataloader): + # if resume from checkpoint, skip the first batches of data in + # the train dataloader + if step < (self.state.global_steps % (self.state.current_epoch + 1)): + continue + inputs = inputs['input_ids'] inputs = inputs.to(self.args.device) @@ -164,7 +270,7 @@ class Trainer(): logits = logits[..., :-1, :].contiguous().view( -1, self.model.module.vocab_size - if dist_utils.is_dist_available_and_initialized() + if dist_utils.is_dist_available_and_initialized() # TODO this should probably be a hasattr(self.model, 'module') else self.model.vocab_size ) @@ -173,13 +279,16 @@ class Trainer(): tr_loss = loss.item() loss = loss / self.args.gradient_accumulation_steps # normalize to account for gradient accumulation + # TODO only for fp16, not fp32 or bf16 self.scaler.scale(loss).backward() # update only after gradient_accumulation_steps if (step + 1) % self.args.gradient_accumulation_steps == 0: - # Note: This will ignore the last few batches of the dataset, - # when the gradient accumulation steps are more than 1, and the - # number of batches doesn't cleanly divide by grad_acc_steps + # Note: This will ignore the last few batches of the dataset + # when grad_acc_steps > 1 and the number of batches isn't + # divisible by grad_acc_steps + + self.state.global_steps += 1 # gradient clipping self.scaler.unscale_(self.args.optimizer) @@ -193,11 +302,21 @@ class Trainer(): self.args.lr_scheduler.step() # lr = self.args.lr_scheduler.get_last_lr()[0] - if (step + 1) % self.args.log_steps * self.args.gradient_accumulation_steps == 0: + if step % self.args.log_steps == 0: logger.info(f"Loss is {tr_loss:,}") self.progress.update(1) + if step % self.args.save_steps == 0: + if dist_utils.get_rank() == 0: # TODO this only works for DDP, not FSDP + # save checkpoint + print(f'saving checkpoint at step {self.state.global_steps}') + os.makedirs(os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}')) + self.state.to_json_file(os.path.join(self.args.checkpoints_dir, os.path.join(f'checkpoint-{self.state.global_steps}', 'trainer_state.json'))) + torch.save(self.model.state_dict(), + os.path.join(self.args.checkpoints_dir, + os.path.join(f'checkpoint-{self.state.global_steps}', 'model.pt'))) + def save_model(self, save_dir: Path) -> None: """ Save model and tokenizer to a directory. 
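
The commented-out checkpoint_utils.load_from_checkpoint(...) call above is still a TODO. The sketch below is not part of this patch; it is only one possible shape for that helper pair, assuming the file layout already used in the training loop (trainer_state.json plus model.pt) and hypothetical optimizer.pt / scheduler.pt files for the remaining state. Names and signatures are assumptions.

    # hypothetical sketch, not applied by this diff
    import os
    import torch
    from optimus.trainer import TrainerState

    def save_checkpoint(chkpt, trainer_state, model, optimizer, scheduler):
        # persist everything needed to resume training into chkpt/
        os.makedirs(chkpt, exist_ok=True)
        trainer_state.to_json_file(os.path.join(chkpt, 'trainer_state.json'))
        torch.save(model.state_dict(), os.path.join(chkpt, 'model.pt'))
        torch.save(optimizer.state_dict(), os.path.join(chkpt, 'optimizer.pt'))
        torch.save(scheduler.state_dict(), os.path.join(chkpt, 'scheduler.pt'))

    def load_from_checkpoint(chkpt, model, optimizer, scheduler):
        # restore the objects saved above; the reloaded TrainerState is
        # returned so the caller can assign it to self.state
        model.load_state_dict(torch.load(os.path.join(chkpt, 'model.pt')))
        optimizer.load_state_dict(torch.load(os.path.join(chkpt, 'optimizer.pt')))
        scheduler.load_state_dict(torch.load(os.path.join(chkpt, 'scheduler.pt')))
        return TrainerState.from_json_file(os.path.join(chkpt, 'trainer_state.json'))

Restoring the optimizer and scheduler this way would also make the learning-rate fast-forward loop in train() unnecessary; the GradScaler state could be saved the same way if fp16 runs need to resume bit-exact.
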
diff --git a/optimus/utils/checkpoint_utils.py b/optimus/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bee6b833033cb182970ee5676b0b16fb25961c40
--- /dev/null
+++ b/optimus/utils/checkpoint_utils.py
@@ -0,0 +1,84 @@
+import os
+import re
+import logging
+from pathlib import Path
+
+from optimus.trainer import TrainerState
+
+
+logger = logging.getLogger(__name__)
+
+
+# files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+# FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
+def get_last_checkpoint_dir(dir: Path) -> Path | None:
+    """
+    Get the last saved (highest global step) checkpoint in a directory,
+    matching the pattern `checkpoint-<number>`, where number is the global
+    step of training. Each checkpoint is itself a directory inside `dir`.
+
+    Args:
+        dir (Path): The checkpoints directory.
+
+    Returns (Path or None): Returns the path to the last checkpoint, or `None`
+        if no valid checkpoints were found.
+    """
+    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
+    checkpoints_dirs = os.listdir(dir)
+    checkpoints = [
+        path
+        for path in checkpoints_dirs
+        if checkpoint_regex.search(path) is not None and os.path.isdir(os.path.join(dir, path))
+    ]
+    if len(checkpoints) == 0:
+        return None
+    return Path(os.path.join(
+        dir,
+        max(checkpoints,
+            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
+    ))
+
+
+def load_from_checkpoint(
+        chkpt: Path,
+        trainer_state: TrainerState,
+        # model: nn.Module,
+        # optimizer: torch.optim.Optimizer,
+        # scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ):
+    """
+    Load the training state from an existing checkpoint. This ensures training
+    can resume exactly as it was before checkpointing.
+
+    Args:
+        chkpt (Path): Path to checkpoint directory.
+        trainer_state (TrainerState): Trainer state to load.
+    """
+    # TODO also load model, optimizer and scheduler state, and return (or
+    # apply) the loaded trainer state
+    trainer_state = TrainerState.from_json_file(os.path.join(chkpt, TRAINER_STATE_NAME))
+
+
+def save_checkpoint(
+        chkpt: Path,
+        trainer_state: TrainerState,
+    ):
+    """
+    Save the training state as a checkpoint. This ensures training can be
+    resumed exactly as it was before saving.
+
+    Args:
+        chkpt (Path): Path to checkpoint directory.
+        trainer_state (TrainerState): Trainer state to save.
+    """
+    pass
+    # chkpt.mkdir(parents=True, exist_ok=True)
+    # trainer_state.to_json_file()
diff --git a/training.py b/training.py
index 9a6feea0068398aa9f0e1a235cce781e92a98e94..4f0932f62a34b8f08bb8a1ede80da77a0e8d9bf8 100644
--- a/training.py
+++ b/training.py
@@ -29,6 +29,9 @@ def main(
     seed: int = 42,
     log_on_main_rank_only: bool = True,
 
+    # checkpointing
+    resume_from_checkpoint: bool | Path = False,
+
     # model + tokenizer
     model_config_path: Path = Path('config.json'),
     hf_tokenizer_name: str = 'gpt2',
@@ -37,6 +40,9 @@ def main(
     # directory than the source code)
     data_dir: Path = Path('.'),
 
+    # TODO make a different scratch_dir for checkpoints, instead of re-using
+    # data_dir; similarly think about save_dir and logs_dir
+
     # dataset dir (appended to data_dir)
     dataset_dir: Path = Path('gpt2_tokenized_wikitext103_no_seq_len'),
 
@@ -68,6 +74,8 @@ def main(
             have its initial PyTorch seed set to `seed + process_rank`.
         log_on_main_rank_only (bool): Whether output should only be printed by
             the main rank (rank 0).
+        resume_from_checkpoint (bool|Path): `True` or a valid checkpoint path
+            to resume training from an existing checkpoint.
         model_config_path (Path): Path to a config file which describes the
             model to be trained.
         hf_tokenizer_name (str): HuggingFace tokenizer name.
@@ -117,20 +125,20 @@ def main(
     else:
         dataset_dir = data_dir / dataset_dir
 
-    if isinstance(checkpoints_dir, str):
-        checkpoints_dir = data_dir / Path(checkpoints_dir)
-    else:
-        checkpoints_dir = data_dir / checkpoints_dir
+    #if isinstance(checkpoints_dir, str):
+    #    checkpoints_dir = data_dir / Path(checkpoints_dir)
+    #else:
+    #    checkpoints_dir = data_dir / checkpoints_dir
 
-    if isinstance(log_dir, str):
-        log_dir = data_dir / Path(log_dir)
-    else:
-        log_dir = data_dir / log_dir
+    #if isinstance(log_dir, str):
+    #    log_dir = data_dir / Path(log_dir)
+    #else:
+    #    log_dir = data_dir / log_dir
 
-    if isinstance(save_dir, str):
-        save_dir = data_dir / Path(save_dir)
-    else:
-        save_dir = data_dir / save_dir
+    #if isinstance(save_dir, str):
+    #    save_dir = data_dir / Path(save_dir)
+    #else:
+    #    save_dir = data_dir / save_dir
 
     logger.info(f"Running with:\n"
                 f"\t- batch size: {batch_size}\n"
@@ -144,6 +152,7 @@ def main(
                 f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
                 f"\t- seed: {seed}\n"
                 f"\t- only main rank logs: {log_on_main_rank_only}\n"
+                f"\t- resume from checkpoint: {resume_from_checkpoint}\n"
                 f"\t- model config file: {model_config_path}\n"
                 f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
                 f"\t- training data directory: {str(data_dir)}\n"
@@ -266,7 +275,7 @@ def main(
 
     logger.info('Starting training...')
 
-    trainer.train()
+    trainer.train(resume_from_checkpoint)
 
     logger.info(f"Finished training! Saving model weights to '{str(save_dir)}'")
 
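
For reference, this is roughly how the pieces added in this diff fit together when resuming; the snippet is illustrative only (not part of the patch) and the directory names and state values are hypothetical.

    from pathlib import Path
    from optimus.trainer import TrainerState
    from optimus.utils import checkpoint_utils

    # a checkpoints dir containing e.g. checkpoint-500/, checkpoint-1000/, checkpoint-1500/
    last = checkpoint_utils.get_last_checkpoint_dir(Path('checkpoints'))
    # -> checkpoints/checkpoint-1500 (highest global step wins, not modification time)

    # each checkpoint holds trainer_state.json (the TrainerState fields, e.g.
    # {"current_epoch": 2, "global_steps": 1500}) next to model.pt
    state = TrainerState.from_json_file(last / 'trainer_state.json')

    # Trainer.train(resume_from_checkpoint=True) performs the same lookup
    # internally, reloads the trainer state and model weights, fast-forwards
    # the lr scheduler and skips the already-consumed batches before continuing
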