diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index 116675bbf3bab2708391abd773166a6818051e33..fc11a93fafa6c8fb6e3a1f00281ca05b2b352983 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -4,6 +4,7 @@ import json
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 class OptimusConfig():
@@ -51,6 +52,12 @@ class OptimusConfig():
             projection matrices in the self-attention layer.
         attention_dropout (float, defaults to 0.0):
             Dropout ratio for the attention probabilities.
+        gradient_checkpointing (bool, defaults to True): Whether to enable
+            gradient checkpointing (also known as activation checkpointing),
+            which trades compute for memory. Every transformer decoder layer
+            is checkpointed: only layer inputs are kept in memory, and each
+            layer's intermediate activations are recomputed during the backward
+            pass, costing roughly one extra forward pass (~30% more compute).
 
     """
     def __init__(
@@ -68,6 +75,7 @@ class OptimusConfig():
         attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
+        gradient_checkpointing=True,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -82,6 +90,7 @@ class OptimusConfig():
         self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.gradient_checkpointing = gradient_checkpointing
         super().__init__()
 
     @classmethod
@@ -386,7 +395,17 @@ class OptimusTransformer(nn.Module):
 
         # go through all model layers
         for layer in self.layers:
-            x = layer(x, attn_mask) # (batch_size, seq_len, dim)
+            if self.config.gradient_checkpointing and self.training:
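+                # recompute activations in backward instead of storing them;
+                # use_reentrant=False is the variant recommended by PyTorch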
+                x = checkpoint(
+                    layer.__call__,
+                    x,
+                    attn_mask,
+                    use_reentrant=False,
+                ) # (batch_size, seq_len, dim)
+            else:
+                x = layer(x, attn_mask) # (batch_size, seq_len, dim)
 
         # final norm + linear layer
         x = self.output_norm(x)
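
For reference, a minimal usage sketch of the new flag (this assumes OptimusConfig's remaining arguments have defaults and that OptimusTransformer is constructed from the config object; adjust to the real constructor signatures if they differ). Checkpointing only kicks in during training, since the forward pass also checks self.training:

    from optimus.models.optimus import OptimusConfig, OptimusTransformer

    # assumed constructor calls; gradient_checkpointing already defaults to True
    config = OptimusConfig(gradient_checkpointing=True)
    model = OptimusTransformer(config)

    model.train()  # checkpointed forward: activations recomputed in backward
    model.eval()   # plain forward: layers are called directly, nothing recomputed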