Unverified Commit 32ac537c authored by Alexandru-Mihai GHERGHESCU

WIP training checkpointing

parent cb1a7974
Pipeline #72414 passed
import os
import json
import logging
from pathlib import Path
from typing import Optional
@@ -8,7 +10,7 @@ import torch.nn as nn
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from optimus.utils import dist_utils, logging_utils, checkpoint_utils

logger = logging.getLogger(__name__)
@@ -81,6 +83,47 @@ class TrainingArguments():
        self.save_limit = save_limit

class TrainerState():
    """
    State of a trainer. This contains information that needs to be persisted
    across checkpoints, such as the current epoch or the current number of
    steps.
    """
    def __init__(self):
        self.current_epoch = 0

        # Total number of weight updates over the whole training loop, across
        # all epochs. With gradient accumulation steps > 1 there are more
        # forward passes than weight updates, but those are not counted here.
        # Global steps don't reset per epoch; they are roughly equal to
        # epochs * steps_per_epoch.
        self.global_steps = 0

    def to_json_file(self, file: Path):
        """
        Save the current trainer state as a json file.

        Args:
            file (Path): Path where the trainer state json file should be
                saved.
        """
        logger.debug('Saving trainer state to json file')
        with open(file, 'w', encoding='utf-8') as f:
            json.dump(vars(self), f, ensure_ascii=False, indent=4)

    # TODO probably should just be to_json
    @classmethod
    def from_json_file(cls, file: Path | str):
        """
        Reload the state of a trainer from a json file. Note: It is very
        likely that training *will not work as expected* if you resume
        training from a checkpoint with a trainer state that has been
        manually altered.

        Args:
            file (Path): Path to a json file (usually contained inside a
                checkpoint directory).

        Returns (TrainerState): The trainer state loaded from the json file.
        """
        logger.debug('Loading trainer state from json file')
        with open(file, 'r', encoding='utf-8') as f:
            state_dict = json.load(f)

        state = cls()
        for key, value in state_dict.items():
            setattr(state, key, value)
        return state
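
# A minimal usage sketch of TrainerState persistence (illustrative only; the
# checkpoint path below is hypothetical, but the 'trainer_state.json' filename
# matches what the checkpointing code writes):
#
#     state = TrainerState()
#     state.current_epoch = 1
#     state.global_steps = 1500
#     state.to_json_file(Path('checkpoint-1500/trainer_state.json'))
#     restored = TrainerState.from_json_file('checkpoint-1500/trainer_state.json')
#     assert vars(restored) == vars(state)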

class Trainer():
    """
    Generic PyTorch trainer implementation.
@@ -112,13 +155,24 @@ class Trainer():
        self.eval_dataloader = eval_dataloader
        self.tokenizer = tokenizer

        # initialize an empty trainer state (if resuming from a checkpoint,
        # will be updated later)
        self.state = TrainerState()

        if self.args.log_on_main_rank_only:
            logger.addFilter(logging_utils.FilterMainRankOnly())

    def train(self, resume_from_checkpoint: bool | Path = False) -> None:
        """
        Training loop of the trainer.

        Args:
            resume_from_checkpoint (bool or Path): Whether training should
                resume from a checkpoint or start from scratch. If True,
                training resumes from the last saved checkpoint (highest
                global step) in `checkpoints_dir` (see `TrainingArguments`).
                If a Path, training resumes from that specific checkpoint
                directory.
        """

        num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
        num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
@@ -134,6 +188,44 @@ class Trainer():
        # scaler used for mixed precision fp16 training on GPU
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)

        # initialize a Trainer state (will be overridden later if resuming
        # from checkpoint)
        self.state = TrainerState()

        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # potentially resume from saved checkpoint
        if resume_from_checkpoint:
            if isinstance(resume_from_checkpoint, bool):
                last_checkpoint: Path | None = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir)
            else:
                last_checkpoint = Path(resume_from_checkpoint)

            if last_checkpoint is not None:
                logger.info(f"Resuming training from checkpoint {str(last_checkpoint)}")

                # TODO check what model.state_dict, opt.state_dict etc. really
                # is, and see how to replicate for trainer state
                # TODO possibly move trainer state and trainer args to a
                # separate trainer utils, since state should most likely be
                # held separate from behaviour (Trainer) anyway, and also
                # because TrainerState and checkpoint_utils have a circular
                # dependency
                self.state = TrainerState.from_json_file(os.path.join(last_checkpoint, 'trainer_state.json'))

                # load model weights from .pt
                self.model.load_state_dict(torch.load(os.path.join(last_checkpoint, 'model.pt')))

                # TODO write checkpoint_utils.load_from_checkpoint() and
                # checkpoint_utils.save_checkpoint(), then use them here
                # checkpoint_utils.load_from_checkpoint(
                #     last_checkpoint,
                #     self.state,
                #     self.model,
                #     self.optimizer,
                #     self.scheduler,
                #     #self.scaler, # ?
                # )

        self.model.train()

        logger.info("***** Running training *****")
        logger.info(f" Num examples = {num_examples:,}")
        logger.info(f" Num epochs = {self.args.num_train_epochs:,}")
@@ -142,17 +234,24 @@ class Trainer():
        logger.info(f" Global batch size (w. distributed & accumulation) = {global_batch_size:,}")
        logger.info(f" Total optimization steps = {max_steps:,}")

        # if resuming from a checkpoint, fast-forward the learning rate
        # scheduler by the number of weight updates already performed
        for _ in range(self.state.global_steps):
            self.args.lr_scheduler.step()

        # start training for num_train_epochs (or less if resumed from checkpoint)
        for epoch in range(self.state.current_epoch, self.args.num_train_epochs):

            # needed for distributed sampler RNG state
            if hasattr(self.train_dataloader, "set_epoch"):
                self.train_dataloader.set_epoch(epoch)

            for step, inputs in enumerate(self.train_dataloader):

                # if resuming from a checkpoint, skip the batches of the
                # current epoch which have already been trained on
                # TODO double-check this interacts correctly with gradient
                # accumulation
                if (epoch == self.state.current_epoch
                        and step < (self.state.global_steps % num_update_steps_per_epoch)
                        * self.args.gradient_accumulation_steps):
                    continue

                inputs = inputs['input_ids']
                inputs = inputs.to(self.args.device)
@@ -164,7 +263,7 @@ class Trainer():
                logits = logits[..., :-1, :].contiguous().view(
                    -1,
                    self.model.module.vocab_size
                    if dist_utils.is_dist_available_and_initialized() # TODO this should probably be a hasattr(self.model, 'module')
                    else self.model.vocab_size
                )
@@ -177,9 +276,11 @@ class Trainer():
                # update only after gradient_accumulation_steps
                if (step + 1) % self.args.gradient_accumulation_steps == 0:

                    # Note: This will ignore the last few batches of the dataset
                    # when grad_acc_steps > 1 and the number of batches isn't
                    # divisible by grad_acc_steps
                    steps_trained_in_current_epoch += 1

                    # gradient clipping
                    self.scaler.unscale_(self.args.optimizer)
@@ -193,11 +294,21 @@ class Trainer():
                    self.args.lr_scheduler.step()
                    # lr = self.args.lr_scheduler.get_last_lr()[0]

                    if steps_trained_in_current_epoch % self.args.log_steps == 0:
                        logger.info(f"Loss is {tr_loss:,}")

                    self.progress.update(1)

                    if steps_trained_in_current_epoch % self.args.save_steps == 0:
                        if dist_utils.get_rank() == 0: # TODO this only works for DDP
                            # save checkpoint
                            logger.info(f"Saving checkpoint at step {steps_trained_in_current_epoch}")
                            checkpoint_dir = os.path.join(self.args.checkpoints_dir,
                                                          f'checkpoint-{self.state.global_steps}')
                            os.makedirs(checkpoint_dir, exist_ok=True)
                            self.state.to_json_file(os.path.join(checkpoint_dir, 'trainer_state.json'))
                            torch.save(self.model.state_dict(),
                                       os.path.join(checkpoint_dir, 'model.pt'))
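
                    # NOTE: TrainingArguments.save_limit is not enforced yet.
                    # A possible sketch (assuming the `checkpoint-<step>`
                    # naming above, and `import shutil`) would prune the
                    # oldest checkpoints after each save:
                    #
                    #     ckpts = sorted(
                    #         Path(self.args.checkpoints_dir).glob('checkpoint-*'),
                    #         key=lambda p: int(p.name.split('-')[1]))
                    #     for old in ckpts[:-self.args.save_limit]:
                    #         shutil.rmtree(old)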

    def save_model(self, save_dir: Path) -> None:
        """
        Save model and tokenizer to a directory.
...
import os
import re
import logging
from pathlib import Path
from optimus.trainer import TrainerState
logger = logging.getLogger(__name__)
# files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
# FSDP_MODEL_NAME = "pytorch_model_fsdp"
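
# Expected on-disk layout of the checkpoints directory (illustrative sketch;
# only trainer_state.json and model.pt are written by the checkpointing code
# in this commit, the other filenames above are not used yet):
#
#     checkpoints_dir/
#         checkpoint-1000/
#             trainer_state.json
#             model.pt
#         checkpoint-1500/
#             trainer_state.json
#             model.pt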

def get_last_checkpoint_dir(dir: Path) -> Path | None:
    """
    Get the last saved (highest global step) checkpoint in a directory,
    matching the pattern `checkpoint-<number>`, where `number` is the global
    step of training. Each checkpoint should itself be a directory.

    Args:
        dir (Path): The checkpoints directory.

    Returns (Path or None): Returns the path to the last checkpoint, or `None`
        if no valid checkpoints were found.
    """
    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
    checkpoints_dirs = os.listdir(dir)
    checkpoints = [
        path
        for path in checkpoints_dirs
        if checkpoint_regex.search(path) is not None and os.path.isdir(os.path.join(dir, path))
    ]

    if len(checkpoints) == 0:
        return None

    return Path(os.path.join(
        dir,
        max(checkpoints,
            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
    ))
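
# Example (illustrative): with `checkpoints/checkpoint-500` and
# `checkpoints/checkpoint-1500` on disk,
# `get_last_checkpoint_dir(Path('checkpoints'))` returns
# `Path('checkpoints/checkpoint-1500')`; if no `checkpoint-<number>`
# subdirectories exist, it returns `None`.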

def load_from_checkpoint(
    chkpt: Path,
    trainer_state: TrainerState,
    # model: nn.Module,
    # optimizer: torch.optim.Optimizer,
    # scheduler: torch.optim.lr_scheduler.LRScheduler,
):
    """
    Load the trainer state from an existing checkpoint. This ensures training
    can resume exactly as it was before checkpointing.

    Args:
        chkpt (Path): Path to checkpoint directory.
        trainer_state (TrainerState): Trainer state to load into.
    """
    # update the passed-in state in place with the values saved in the
    # checkpoint, so callers holding a reference to it see the loaded values
    loaded_state = TrainerState.from_json_file(os.path.join(chkpt, TRAINER_STATE_NAME))
    vars(trainer_state).update(vars(loaded_state))

def save_checkpoint(
    chkpt: Path,
    trainer_state: TrainerState,
):
    """
    Save the trainer state as a checkpoint. This ensures training can be
    resumed exactly as it was before saving.

    Args:
        chkpt (Path): Path to checkpoint directory.
        trainer_state (TrainerState): Trainer state to save.
    """
    chkpt.mkdir(parents=True, exist_ok=True)
    trainer_state.to_json_file(chkpt / TRAINER_STATE_NAME)
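
# A possible shape for full checkpointing (sketch only, not the final API;
# this mirrors the commented-out parameters above and the commented-out
# `checkpoint_utils.load_from_checkpoint(...)` call in Trainer.train(), and
# would additionally need `import torch`):
#
#     def save_checkpoint(chkpt, trainer_state, model, optimizer, scheduler):
#         chkpt.mkdir(parents=True, exist_ok=True)
#         trainer_state.to_json_file(chkpt / TRAINER_STATE_NAME)
#         torch.save(model.state_dict(), chkpt / 'model.pt')
#         torch.save(optimizer.state_dict(), chkpt / OPTIMIZER_NAME)
#         torch.save(scheduler.state_dict(), chkpt / SCHEDULER_NAME)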
@@ -29,6 +29,9 @@ def main(
    seed: int = 42,
    log_on_main_rank_only: bool = True,

    # checkpointing
    resume_from_checkpoint: bool | Path = False,

    # model + tokenizer
    model_config_path: Path = Path('config.json'),
    hf_tokenizer_name: str = 'gpt2',
@@ -37,6 +40,9 @@ def main(
    # directory than the source code)
    data_dir: Path = Path('.'),

    # TODO make a different scratch_dir for checkpoints, instead of re-using
    # data_dir; similarly think about save_dir and logs_dir

    # dataset dir (appended to data_dir)
    dataset_dir: Path = Path('gpt2_tokenized_wikitext103_no_seq_len'),
@@ -68,6 +74,8 @@ def main(
            have its initial PyTorch seed set to `seed + process_rank`.
        log_on_main_rank_only (bool): Whether output should only be printed by
            the main rank (rank 0).
        resume_from_checkpoint (bool|Path): True, or a valid checkpoint path,
            to resume training from an existing checkpoint.
        model_config_path (Path): Path to a config file which describes the
            model to be trained.
        hf_tokenizer_name (str): HuggingFace tokenizer name.
@@ -117,20 +125,20 @@ def main(
    else:
        dataset_dir = data_dir / dataset_dir

    #if isinstance(checkpoints_dir, str):
    #    checkpoints_dir = data_dir / Path(checkpoints_dir)
    #else:
    #    checkpoints_dir = data_dir / checkpoints_dir

    #if isinstance(log_dir, str):
    #    log_dir = data_dir / Path(log_dir)
    #else:
    #    log_dir = data_dir / log_dir

    #if isinstance(save_dir, str):
    #    save_dir = data_dir / Path(save_dir)
    #else:
    #    save_dir = data_dir / save_dir
logger.info(f"Running with:\n" logger.info(f"Running with:\n"
f"\t- batch size: {batch_size}\n" f"\t- batch size: {batch_size}\n"
...@@ -144,6 +152,7 @@ def main( ...@@ -144,6 +152,7 @@ def main(
f"\t- 16-bit floating-point training (fp16): {use_fp16}\n" f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
f"\t- seed: {seed}\n" f"\t- seed: {seed}\n"
f"\t- only main rank logs: {log_on_main_rank_only}\n" f"\t- only main rank logs: {log_on_main_rank_only}\n"
f"\t- resume from checkpoint: {resume_from_checkpoint}\n"
f"\t- model config file: {model_config_path}\n" f"\t- model config file: {model_config_path}\n"
f"\t- huggingface tokenizer: {hf_tokenizer_name}\n" f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
f"\t- training data directory: {str(data_dir)}\n" f"\t- training data directory: {str(data_dir)}\n"
@@ -266,7 +275,7 @@ def main(
    logger.info('Starting training...')

    trainer.train(resume_from_checkpoint)

    logger.info(f"Finished training! Saving model weights to '{str(save_dir)}'")
...