Draft: WIP Training checkpointing

Open Alexandru-Mihai GHERGHESCU requested to merge training_checkpointing into main
3 files changed: +200 −8
File shown below: +84 −0
import os
import re
import logging
from pathlib import Path

from optimus.trainer import TrainerState

logger = logging.getLogger(__name__)

# Files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
# FSDP_MODEL_NAME = "pytorch_model_fsdp"
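
# A `checkpoint-<step>` directory is expected to eventually contain the files
# named above (training args, trainer state, optimizer, scheduler, scaler);
# only `trainer_state.json` is handled in this draft so far.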

def get_last_checkpoint_dir(dir: Path) -> Path | None:
    """
    Get the last saved (highest global step) checkpoint in a directory,
    matching the pattern `checkpoint-<number>`, where `<number>` is the global
    step of training. Each checkpoint should itself be a sub-directory.

    Args:
        dir (Path): The checkpoints directory.

    Returns (Path or None): Returns the path to the last checkpoint, or `None`
        if no valid checkpoints were found.
    """
    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
    checkpoints_dirs = os.listdir(dir)
    checkpoints = [
        path
        for path in checkpoints_dirs
        if checkpoint_regex.search(path) is not None
        and os.path.isdir(os.path.join(dir, path))
    ]
    if len(checkpoints) == 0:
        return None
    return Path(os.path.join(
        dir,
        max(checkpoints,
            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
    ))
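
# Illustrative example (directory names are hypothetical): given an `out/`
# directory containing `out/checkpoint-500/` and `out/checkpoint-1500/`,
# `get_last_checkpoint_dir(Path("out"))` returns `Path("out/checkpoint-1500")`;
# if no sub-directory matches the `checkpoint-<number>` pattern, it returns
# `None`.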

def load_from_checkpoint(
    chkpt: Path,
    trainer_state: TrainerState,
    # model: nn.Module,
    # optimizer: torch.optim.Optimizer,
    # scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> TrainerState:
    """
    Load the trainer state (TODO: model, optimizer and scheduler states) from
    an existing checkpoint. This ensures training can resume exactly as it was
    before checkpointing.

    Args:
        chkpt (Path): Path to the checkpoint directory.
        trainer_state (TrainerState): Trainer state to load.

    Returns (TrainerState): The trainer state restored from the checkpoint.
    """
    # Rebinding the argument alone would not propagate the loaded state back
    # to the caller, so the restored state is returned instead.
    trainer_state = TrainerState.from_json_file(os.path.join(chkpt, TRAINER_STATE_NAME))
    return trainer_state

def save_checkpoint(
    chkpt: Path,
    trainer_state: TrainerState,
):
    """
    Save the trainer state (TODO: model, optimizer and scheduler states) as a
    checkpoint. This ensures training can be resumed exactly as it was before
    saving.

    Args:
        chkpt (Path): Path to the checkpoint directory.
        trainer_state (TrainerState): Trainer state to save.
    """
    # TODO: sketch of the intended implementation (assuming
    # `TrainerState.to_json_file` mirrors `from_json_file` and takes the
    # destination path):
    # chkpt.mkdir(parents=True, exist_ok=True)
    # trainer_state.to_json_file(os.path.join(chkpt, TRAINER_STATE_NAME))
    pass
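
# Hedged usage sketch (illustrative, not part of this MR): a trainer that
# supports resuming could wire these helpers up roughly as follows. The
# `output_dir` and `global_step` names are assumptions.
#
#     last_chkpt = get_last_checkpoint_dir(Path(output_dir))
#     if last_chkpt is not None:
#         trainer_state = load_from_checkpoint(last_chkpt, trainer_state)
#     ...  # run training
#     save_checkpoint(
#         Path(output_dir) / f"checkpoint-{trainer_state.global_step}",
#         trainer_state,
#     )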