diff --git a/optimus/trainer.py b/optimus/trainer.py
index 00497ab257179c1ff179c0e49a20103f59296d97..3c22d7ebda404807de8c1c7550e01c68edddc15a 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -8,12 +8,22 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from optimus.utils import dist_utils, logging_utils
+from optimus.utils import dist_utils, logging_utils, checkpoint_utils
 
 
 logger = logging.getLogger(__name__)
 
 
+# files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
 class TrainingArguments():
     """
     Training arguments class to hold important switches and knobs related to
@@ -81,6 +91,37 @@ class TrainingArguments():
         self.save_limit = save_limit
 
 
+class TrainerState():
+    """
+    State of a trainer. This contains information that needs to be persisted
+    across checkpoints, such as current epoch or current number of steps.
+    """
+    def __init__(self):
+        self.current_epoch = 0
+        self.global_steps = 0
+
+    def to_json_file(self, file: Path):
+        """
+        Save the current trainer state as a json file.
+
+        Args:
+            file (Path): Path where the trainer state json file should be saved.
+        """
+        pass # TODO
+
+    def from_json_file(self, file: Path):
+        """
+        Reload the state of a trainer from a json file. Note: It is very likely
+        that training *will not work as expected* if you resume training from a
+        checkpoint with a trainer state that has been manually altered.
+
+        Args:
+            file (Path): Path to a json file (usually contained inside a
+                checkpoint directory).
+        """
+        pass # TODO
+
+
 class Trainer():
     """
     Generic PyTorch trainer implementation.
@@ -112,6 +153,10 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        # initialize an empty trainer state (if resuming from a checkpoint,
+        # will be updated later)
+        self.state = TrainerState()
+
         if self.args.log_on_main_rank_only:
             logger.addFilter(logging_utils.FilterMainRankOnly())
 
@@ -119,6 +164,13 @@ class Trainer():
         """
         Training loop of the trainer.
 
+        Args:
+            resume_from_checkpoint (bool or Path): Whether training should
+                resume from a checkpoint or start from scratch. If yes, then
+                training resumes from the last saved checkpoint (the one with
+                the highest global step) in `checkpoints_dir` (see
+                `TrainingArguments`).
+
         """
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
@@ -134,6 +186,26 @@ class Trainer():
         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
 
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+
+        # potentially resume from saved checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            last_checkpoint: Path = checkpoint_utils.get_last_checkpoint_dir(self.args.checkpoints_dir)
+            if last_checkpoint is not None:
+                logger.info(f"Resuming training from checkpoint {str(last_checkpoint)}")
+
+                checkpoint_utils.load_from_checkpoint(
+                    last_checkpoint,
+                    self.state,
+                    self.model,
+                    self.optimizer,
+                    self.scheduler,
+                    #self.scaler, # ?
+                )
+
+        self.model.train()
+
         logger.info("***** Running training *****")
         logger.info(f" Num examples = {num_examples:,}")
         logger.info(f" Num epochs = {self.args.num_train_epochs:,}")
@@ -142,16 +214,24 @@ class Trainer():
         logger.info(f" Global batch size (w. distributed & accumulation) = {global_batch_size:,}")
         logger.info(f" Total optimization steps = {max_steps:,}")
 
-        self.model.train()
-
-        # start training for num_train_epochs
-        for epoch in range(self.args.num_train_epochs):
+        # start training for num_train_epochs (or less if resumed from checkpoint)
+        for epoch in range(self.state.current_epoch, self.args.num_train_epochs):
 
             # needed for distributed sampler RNG state
             if hasattr(self.train_dataloader, "set_epoch"):
                 self.train_dataloader.set_epoch(epoch)
 
+            # resume from checkpoint, update steps
+            # TODO grad_acc_steps for this?
+            for step in range(0, self.state.global_steps):
+                self.args.lr_scheduler.step()
+
             for step, inputs in enumerate(self.train_dataloader):
+                steps_trained_in_current_epoch += 1
+
+                # check if resume from training
+                if steps_trained_in_current_epoch < self.state.global_steps:
+                    continue
 
                 inputs = inputs['input_ids']
                 inputs = inputs.to(self.args.device)
@@ -164,7 +244,7 @@ class Trainer():
                 logits = logits[..., :-1, :].contiguous().view(
                     -1,
                     self.model.module.vocab_size
-                    if dist_utils.is_dist_available_and_initialized()
+                    if dist_utils.is_dist_available_and_initialized() # TODO this should probably be a hasattr(self.model, 'module')
                     else self.model.vocab_size
                 )
diff --git a/optimus/utils/checkpoint_utils.py b/optimus/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4f98a59041f99ef8a54bb34213ab2d181480dd
--- /dev/null
+++ b/optimus/utils/checkpoint_utils.py
@@ -0,0 +1,36 @@
+import os
+import re
+import logging
+from pathlib import Path
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_last_checkpoint_dir(dir: Path):
+    """
+    Get last saved (highest global step) checkpoint in a directory, matching the
+    pattern `checkpoint-<number>`, where number is the global step of training.
+    Each checkpoint should itself be a directory.
+
+    Args:
+        dir (Path): The checkpoints directory.
+
+    Returns (Path or None): Returns the path to the last checkpoint, or `None`
+        if no valid checkpoints were found.
+
+    """
+    checkpoint_regex = re.compile(r"^checkpoint-(\d+)$")
+    checkpoints_dirs = os.listdir(dir)
+    checkpoints = [
+        path
+        for path in checkpoints_dirs
+        if checkpoint_regex.search(path) is not None and os.path.isdir(os.path.join(dir, path))
+    ]
+    if len(checkpoints) == 0:
+        return None
+    return os.path.join(
+        dir,
+        max(checkpoints,
+            key=lambda x: int(checkpoint_regex.search(x).groups()[0]))
+    )
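Note on the resume path: the trainer hunk above calls `checkpoint_utils.load_from_checkpoint(...)`, which this patch does not define, and `TrainerState.to_json_file` / `from_json_file` are still `pass # TODO` stubs. Below is a minimal sketch of what those pieces might look like, assuming a plain JSON round-trip of the state attributes and checkpoint file names mirroring the constants added at the top of `optimus/trainer.py`. The helper bodies and the `pytorch_model.bin` weights file name are illustrative assumptions, not part of the patch.

# Sketch only -- not part of the patch. Possible shape of the helpers the
# trainer hunk relies on, under the assumptions stated above.
import json
from pathlib import Path

import torch

# assumed to mirror the constants added at the top of optimus/trainer.py
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"


def state_to_json_file(state, file: Path):
    # possible body for TrainerState.to_json_file: dump the plain attributes
    with open(file, "w") as f:
        json.dump(vars(state), f, indent=2)


def state_from_json_file(state, file: Path):
    # possible body for TrainerState.from_json_file: restore the attributes
    with open(file) as f:
        for key, value in json.load(f).items():
            setattr(state, key, value)


def load_from_checkpoint(checkpoint_dir: Path, state, model, optimizer, scheduler):
    # restore the epoch/step counters first, then the state dicts, matching the
    # call site in Trainer.train() above
    checkpoint_dir = Path(checkpoint_dir)
    state_from_json_file(state, checkpoint_dir / TRAINER_STATE_NAME)

    # "pytorch_model.bin" is an assumed weights file name; the patch does not
    # name one for non-FSDP checkpoints
    model.load_state_dict(torch.load(checkpoint_dir / "pytorch_model.bin", map_location="cpu"))
    optimizer.load_state_dict(torch.load(checkpoint_dir / OPTIMIZER_NAME, map_location="cpu"))
    scheduler.load_state_dict(torch.load(checkpoint_dir / SCHEDULER_NAME, map_location="cpu"))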