diff --git a/optimus/trainer.py b/optimus/trainer.py
index 00497ab257179c1ff179c0e49a20103f59296d97..b25955e8048e6f69e6b4c605c41aacfb6915c5d1 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,3 +1,5 @@
+import os
+import json
 import logging
 from pathlib import Path
 from typing import Optional
@@ -8,7 +10,7 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from optimus.utils import dist_utils, logging_utils
+from optimus.utils import dist_utils, logging_utils, checkpoint_utils
 
 
 logger = logging.getLogger(__name__)
@@ -81,6 +83,47 @@ class TrainingArguments():
         self.save_limit = save_limit
 
 
+class TrainerState():
+    """
+    State of a trainer. This contains information that needs to be persisted
+    across checkpoints, such as current epoch or current number of steps.
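+
+    A state can be round-tripped through json, e.g. (illustrative paths):
+
+        state = TrainerState()
+        state.global_steps = 1500
+        state.to_json_file(Path('checkpoint-1500/trainer_state.json'))
+        state = TrainerState.from_json_file('checkpoint-1500/trainer_state.json')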
+    """
+    def __init__(self):
+        self.current_epoch = 0
+
+        # total number of weight updates over the whole training loop,
+        # regardless of epochs (roughly num_train_epochs * update_steps_per_epoch);
+        # with gradient_accumulation_steps > 1 there are more forward passes
+        # than weight updates, but only weight updates are counted here
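+        # e.g. (illustrative numbers) 2 epochs at 1000 weight updates per
+        # epoch gives global_steps == 2000 at the end of training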
+        self.global_steps = 0
+
+    def to_json_file(self, file: Path):
+        """
+        Save the current trainer state as a json file.
+
+        Args:
+            file (Path): Path where the trainer state json file should be saved.
+        """
+        logger.info(f"Saving trainer state to {str(file)}")
+        with open(file, 'w', encoding='utf-8') as f:
+            json.dump(vars(self), f, ensure_ascii=False, indent=4)
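+        # resulting file contents (illustrative values):
+        #   {
+        #       "current_epoch": 1,
+        #       "global_steps": 1500
+        #   }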
+        # TODO probably should just be to_json
+
+    @classmethod
+    def from_json_file(cls, file: Path | str):
+        """
+        Reload the state of a trainer from a json file. Note: It is very likely
+        that training *will not work as expected* if you resume training from a
+        checkpoint with a trainer state that has been manually altered.
+
+        Args:
+            file (Path): Path to a json file (usually contained inside a
+            checkpoint directory),
+        """
+        print('loading trainer state from json') #TODO
+
+
 class Trainer():
     """
     Generic PyTorch trainer implementation.
@@ -112,17 +155,37 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        # initialize an empty trainer state (if resuming from a checkpoint,
+        # will be updated later)
+        self.state = TrainerState()
+
         if self.args.log_on_main_rank_only:
+            # TODO this doesn't seem to work?
             logger.addFilter(logging_utils.FilterMainRankOnly())
 
-    def train(self) -> None:
+    def train(self, resume_from_checkpoint: bool | Path = False) -> None:
         """
         Training loop of the trainer.
 
+        Args:
+            resume_from_checkpoint (bool | Path): Whether training should
+                resume from a checkpoint or start from scratch. If `True`,
+                training resumes from the checkpoint with the highest global
+                step found in `checkpoints_dir` (see `TrainingArguments`). If
+                a path is given, training resumes from that checkpoint.
+
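+        Example (assuming an already constructed `trainer`):
+
+            trainer.train(resume_from_checkpoint=True)
+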
         """
+        # number of examples in dataloader (globally, not per GPU); an example
+        # is a sequence of length context_len
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
+
+        # weight update steps per epoch (globally, not per GPU)
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
+
+        # total number of weight update steps
         max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
+
+        # global batch size
         global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size()
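+        # e.g. (illustrative numbers) per_device_batch_size=8,
+        # gradient_accumulation_steps=4 and 2 GPUs give a global batch size
+        # of 8 * 4 * 2 = 64 sequences per weight update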
 
         fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -132,27 +195,70 @@ class Trainer():
         self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0)
 
         # scaler used for mixed precision fp16 training on GPU
+        # TODO, with bf16, you don't need a scaler
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
 
+        # reset the trainer state (it will be overridden below if resuming
+        # from a checkpoint)
+        self.state = TrainerState()
+
+        # potentially resume from a saved checkpoint
+        if resume_from_checkpoint:
+            last_checkpoint: Path | None
+            if isinstance(resume_from_checkpoint, bool):
+                # pick the checkpoint with the highest global step
+                last_checkpoint = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir)
+            else:
+                last_checkpoint = Path(resume_from_checkpoint)
+
+            if last_checkpoint is None:
+                logger.warning(f"No checkpoint found in {str(self.args.checkpoints_dir)}, starting training from scratch")
+            else:
+                logger.info(f"Resuming training from checkpoint {str(last_checkpoint)}")
+
+                # TODO possibly move TrainerState (and TrainingArguments) to a
+                # separate module, since state should be held separate from
+                # behaviour (Trainer), and TrainerState and checkpoint_utils
+                # currently have a circular dependency
+                self.state = TrainerState.from_json_file(os.path.join(last_checkpoint, 'trainer_state.json'))
+
+                # load model weights from .pt
+                self.model.load_state_dict(torch.load(os.path.join(last_checkpoint, 'model.pt')))
+
+                # TODO also restore optimizer, scheduler and scaler state
+                # (checkpoint_utils.load_from_checkpoint / save_checkpoint)
+
+        self.model.train()
+
         logger.info("***** Running training *****")
-        logger.info(f"  Num examples = {num_examples:,}")
+        logger.info(f"  Total number of examples in train dataloader = {num_examples:,}")
         logger.info(f"  Num epochs = {self.args.num_train_epochs:,}")
         logger.info(f"  Instantaneous batch size per device = {self.args.per_device_batch_size:,}")
         logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
         logger.info(f"  Global batch size (w. distributed & accumulation) = {global_batch_size:,}")
         logger.info(f"  Total optimization steps = {max_steps:,}")
 
-        self.model.train()
-
-        # start training for num_train_epochs
-        for epoch in range(self.args.num_train_epochs):
+        # remember the resume point; self.state is updated as training
+        # progresses, so keep the loaded values in local variables
+        resume_epoch = self.state.current_epoch
+        resume_steps = self.state.global_steps
+
+        # if resuming from a checkpoint, fast-forward the learning rate
+        # scheduler to the saved number of weight updates
+        for _ in range(resume_steps):
+            self.args.lr_scheduler.step()
+
+        # start training for num_train_epochs (or fewer if resumed from checkpoint)
+        for epoch in range(resume_epoch, self.args.num_train_epochs):
+            # keep the trainer state in sync so checkpoints record the epoch
+            self.state.current_epoch = epoch
 
             # needed for distributed sampler RNG state
             if hasattr(self.train_dataloader, "set_epoch"):
                 self.train_dataloader.set_epoch(epoch)
 
             for step, inputs in enumerate(self.train_dataloader):
 
+                # if resuming mid-epoch, skip the batches of the train
+                # dataloader that were already consumed before checkpointing
+                if (epoch == resume_epoch
+                        and step < (resume_steps - resume_epoch * num_update_steps_per_epoch)
+                                   * self.args.gradient_accumulation_steps):
+                    continue
+
                 inputs = inputs['input_ids']
                 inputs = inputs.to(self.args.device)
 
@@ -164,7 +270,7 @@ class Trainer():
                     logits = logits[..., :-1, :].contiguous().view(
                         -1,
                         self.model.module.vocab_size
-                        if dist_utils.is_dist_available_and_initialized()
+                        if dist_utils.is_dist_available_and_initialized() # TODO this should probably be a hasattr(self.model, 'module')
                         else self.model.vocab_size
                     )
 
@@ -173,13 +279,16 @@ class Trainer():
                     tr_loss = loss.item()
                     loss = loss / self.args.gradient_accumulation_steps # normalize to account for gradient accumulation
 
+                # TODO only for fp16, not fp32 or bf16
                 self.scaler.scale(loss).backward()
 
                 # update only after gradient_accumulation_steps
                 if (step + 1) % self.args.gradient_accumulation_steps == 0:
-                    # Note: This will ignore the last few batches of the dataset,
-                    # when the gradient accumulation steps are more than 1, and the
-                    # number of batches doesn't cleanly divide by grad_acc_steps
+                    # Note: This will ignore the last few batches of the dataset
+                    # when grad_acc_steps > 1 and the number of batches isn't
+                    # divisible by grad_acc_steps
+
+                    self.state.global_steps += 1
 
                     # gradient clipping
                     self.scaler.unscale_(self.args.optimizer)
@@ -193,11 +302,21 @@ class Trainer():
                     self.args.lr_scheduler.step()
                     # lr = self.args.lr_scheduler.get_last_lr()[0]
 
-                    if (step + 1) % self.args.log_steps * self.args.gradient_accumulation_steps == 0:
+                    if self.state.global_steps % self.args.log_steps == 0:
                         logger.info(f"Loss is {tr_loss:,}")
 
                     self.progress.update(1)
 
+                    if self.state.global_steps % self.args.save_steps == 0:
+                        if dist_utils.get_rank() == 0: # TODO this only works for DDP, not FSDP
+                            # save checkpoint (trainer state + model weights)
+                            logger.info(f"Saving checkpoint at step {self.state.global_steps}")
+                            checkpoint_dir = os.path.join(self.args.checkpoints_dir, f'checkpoint-{self.state.global_steps}')
+                            os.makedirs(checkpoint_dir, exist_ok=True)
+                            self.state.to_json_file(os.path.join(checkpoint_dir, 'trainer_state.json'))
+                            torch.save(self.model.state_dict(),
+                                       os.path.join(checkpoint_dir, 'model.pt'))
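+                            # resulting checkpoint layout (illustrative step number):
+                            #   <checkpoints_dir>/checkpoint-1500/trainer_state.json
+                            #   <checkpoints_dir>/checkpoint-1500/model.pt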
+
     def save_model(self, save_dir: Path) -> None:
         """
         Save model and tokenizer to a directory.
diff --git a/optimus/utils/checkpoint_utils.py b/optimus/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bee6b833033cb182970ee5676b0b16fb25961c40
--- /dev/null
+++ b/optimus/utils/checkpoint_utils.py
@@ -0,0 +1,84 @@
+import os
+import re
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+# only needed for type hints; imported lazily at runtime to avoid a circular
+# import (optimus.trainer imports this module)
+if TYPE_CHECKING:
+    from optimus.trainer import TrainerState
+
+
+logger = logging.getLogger(__name__)
+
+
+# files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+# FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
+def get_last_checkpoint_dir(dir: Path) -> Path | None:
+    """
+    Get the last saved checkpoint (the one with the highest global step) in a
+    directory. Checkpoints are matched by the pattern `checkpoint-<number>`,
+    where `<number>` is the global training step, and each checkpoint must
+    itself be a subdirectory of `dir`.
+
+    Args:
+        dir (Path): The checkpoints directory.
+
+    Returns (Path or None): The path to the last checkpoint, or `None` if no
+        valid checkpoints were found.
+
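+    For example (hypothetical directory contents), if `dir` contains
+    `checkpoint-500/`, `checkpoint-1500/` and an unrelated file `notes.txt`,
+    this returns `dir / 'checkpoint-1500'`.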
+    """
+    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
+    if not os.path.isdir(dir):
+        return None
+
+    checkpoints_dirs = os.listdir(dir)
+    checkpoints = [
+        path
+        for path in checkpoints_dirs
+        if checkpoint_regex.search(path) is not None and os.path.isdir(os.path.join(dir, path))
+    ]
+    if len(checkpoints) == 0:
+        return None
+    return Path(os.path.join(
+        dir,
+        max(checkpoints,
+            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
+    ))
+
+
+def load_from_checkpoint(
+        chkpt: Path,
+        # model: nn.Module,
+        # optimizer: torch.optim.Optimizer,
+        # scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ) -> 'TrainerState':
+    """
+    Load the trainer state (and, eventually, the model, optimizer and
+    scheduler states) from an existing checkpoint. This ensures training can
+    resume exactly as it was before checkpointing.
+
+    Args:
+        chkpt (Path): Path to checkpoint directory.
+
+    Returns (TrainerState): The trainer state loaded from the checkpoint.
+    """
+    # imported here to avoid a circular import at module load time
+    from optimus.trainer import TrainerState
+
+    # TODO also load model, optimizer, scheduler (and scaler) states
+    return TrainerState.from_json_file(os.path.join(chkpt, TRAINER_STATE_NAME))
+
+
+def save_checkpoint(
+        chkpt: Path,
+        trainer_state: 'TrainerState',
+    ):
+    """
+    Save the trainer state (and, eventually, the model, optimizer and
+    scheduler states) as a checkpoint. This ensures training can be resumed
+    exactly as it was before saving.
+
+    Args:
+        chkpt (Path): Path to checkpoint directory.
+        trainer_state (TrainerState): Trainer state to save.
+    """
+    chkpt.mkdir(parents=True, exist_ok=True)
+    trainer_state.to_json_file(chkpt / TRAINER_STATE_NAME)
+    # TODO also save model, optimizer, scheduler (and scaler) states
diff --git a/training.py b/training.py
index 9a6feea0068398aa9f0e1a235cce781e92a98e94..4f0932f62a34b8f08bb8a1ede80da77a0e8d9bf8 100644
--- a/training.py
+++ b/training.py
@@ -29,6 +29,9 @@ def main(
     seed: int = 42,
     log_on_main_rank_only: bool = True,
 
+    # checkpointing
+    resume_from_checkpoint: bool | Path = False,
+
     # model + tokenizer
     model_config_path: Path = Path('config.json'),
     hf_tokenizer_name: str = 'gpt2',
@@ -37,6 +40,9 @@ def main(
     # directory than the source code)
     data_dir: Path = Path('.'),
 
+    # TODO make a different scratch_dir for checkpoints, instead of re-using
+    # data_dir; similarly think about save_dir and logs_dir
+
     # dataset dir (appended to data_dir)
     dataset_dir: Path = Path('gpt2_tokenized_wikitext103_no_seq_len'),
 
@@ -68,6 +74,8 @@ def main(
             have its initial PyTorch seed set to `seed + process_rank`.
         log_on_main_rank_only (bool): Whether output should only be printed by
             the main rank (rank 0).
+        resume_from_checkpoint (bool | Path): Whether to resume training from
+            the latest checkpoint found in the checkpoints directory, or a
+            path to a specific checkpoint to resume from.
         model_config_path (Path): Path to a config file which describes the
             model to be trained.
         hf_tokenizer_name (str): HuggingFace tokenizer name.
@@ -117,20 +125,20 @@ def main(
     else:
         dataset_dir = data_dir / dataset_dir
 
-    if isinstance(checkpoints_dir, str):
-        checkpoints_dir = data_dir / Path(checkpoints_dir)
-    else:
-        checkpoints_dir = data_dir / checkpoints_dir
+    #if isinstance(checkpoints_dir, str):
+    #    checkpoints_dir = data_dir / Path(checkpoints_dir)
+    #else:
+    #    checkpoints_dir = data_dir / checkpoints_dir
 
-    if isinstance(log_dir, str):
-        log_dir = data_dir / Path(log_dir)
-    else:
-        log_dir = data_dir / log_dir
+    #if isinstance(log_dir, str):
+    #    log_dir = data_dir / Path(log_dir)
+    #else:
+    #    log_dir = data_dir / log_dir
 
-    if isinstance(save_dir, str):
-        save_dir = data_dir / Path(save_dir)
-    else:
-        save_dir = data_dir / save_dir
+    #if isinstance(save_dir, str):
+    #    save_dir = data_dir / Path(save_dir)
+    #else:
+    #    save_dir = data_dir / save_dir
 
     logger.info(f"Running with:\n"
         f"\t- batch size: {batch_size}\n"
@@ -144,6 +152,7 @@ def main(
         f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
         f"\t- seed: {seed}\n"
         f"\t- only main rank logs: {log_on_main_rank_only}\n"
+        f"\t- resume from checkpoint: {resume_from_checkpoint}\n"
         f"\t- model config file: {model_config_path}\n"
         f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
         f"\t- training data directory: {str(data_dir)}\n"
@@ -266,7 +275,7 @@ def main(
 
     logger.info('Starting training...')
 
-    trainer.train()
+    trainer.train(resume_from_checkpoint)
 
     logger.info(f"Finished training! Saving model weights to '{str(save_dir)}'")