From 8247f4a47cab4e5aa0758677470b7e829da30215 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 21:46:22 +0300
Subject: [PATCH] Add gradient checkpointing option to Optimus

Gradient (or activation) checkpointing trades extra compute for lower
activation memory usage. Overall, this should make it easier to train
large models on not-so-large hardware.

Add checkpointing to every layer (same as HuggingFace), as opposed to
every 2nd or 3rd layer, since 1) it is the easiest scheme to implement,
and 2) it offers a good balance between memory saved and extra compute.
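
A minimal usage sketch of the new option (assuming OptimusTransformer is
constructed from an OptimusConfig; the constructor itself is not changed
by this patch):

    from optimus.models.optimus import OptimusConfig, OptimusTransformer

    config = OptimusConfig(gradient_checkpointing=True)
    model = OptimusTransformer(config)

    # Checkpointing is active only in training mode: during backward,
    # each decoder layer's forward pass is recomputed instead of its
    # intermediate activations being kept in memory.
    model.train()

    # Inference is unaffected; layers run without checkpointing.
    model.eval()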
---
 optimus/models/optimus.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index 116675b..fc11a93 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -4,6 +4,7 @@ import json
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 class OptimusConfig():
@@ -51,6 +52,12 @@ class OptimusConfig():
             projection matrices in the self-attention layer.
         attention_dropout (float, defaults to 0.0):
             Dropout ratio for the attention probabilities.
+        gradient_checkpointing (bool, defaults to True): Whether to enable
+            gradient checkpointing (also known as activation checkpointing),
+            which trades compute for memory. Every transformer decoder layer is
+            checkpointed: only its inputs are kept, and its intermediate
+            activations are recomputed during the backward pass. Compute goes
+            up by roughly 30%, since the forward pass effectively runs twice.
 
     """
     def __init__(
@@ -68,6 +75,7 @@ class OptimusConfig():
         attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
+        gradient_checkpointing=True,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -82,6 +90,7 @@ class OptimusConfig():
         self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.gradient_checkpointing = gradient_checkpointing
         super().__init__()
 
     @classmethod
@@ -386,7 +395,15 @@ class OptimusTransformer(nn.Module):
 
         # go through all model layers
         for layer in self.layers:
-            x = layer(x, attn_mask) # (batch_size, seq_len, dim)
+            if self.config.gradient_checkpointing and self.training:
+                x = checkpoint(
+                    layer.__call__,
+                    x,
+                    attn_mask,
+                    use_reentrant=False,
+                ) # (batch_size, seq_len, dim)
+            else:
+                x = layer(x, attn_mask) # (batch_size, seq_len, dim)
 
         # final norm + linear layer
         x = self.output_norm(x)
-- 
GitLab