From 8247f4a47cab4e5aa0758677470b7e829da30215 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 21:46:22 +0300
Subject: [PATCH] Add gradient checkpointing option to Optimus

Gradient (or activation) checkpointing trades extra compute for reduced
memory usage. Overall, this should make it easier to train large models
on not-so-large hardware.

Checkpoint every layer (same as HuggingFace), as opposed to every 2-3
layers, since 1) this is the easiest to implement, and 2) it offers the
best balance between memory and compute.
---
 optimus/models/optimus.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index 116675b..fc11a93 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -4,6 +4,7 @@ import json
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint


 class OptimusConfig():
@@ -51,6 +52,12 @@ class OptimusConfig():
             projection matrices in the self-attention layer.
         attention_dropout (float, defaults to 0.0): Dropout ratio for the
             attention probabilities.
+        gradient_checkpointing (bool, defaults to True): Whether to enable
+            gradient checkpointing (also known as activation checkpointing),
+            which trades compute for memory. This checkpoints every transformer
+            decoder layer. The memory savings are in the ballpark of
+            O(sqrt(n_layers)), while computation increases by ~30% (equivalent
+            to running the forward pass twice instead of once).
     """

     def __init__(
@@ -68,6 +75,7 @@ class OptimusConfig():
         attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
+        gradient_checkpointing=True,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -82,6 +90,7 @@ class OptimusConfig():
         self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.gradient_checkpointing = gradient_checkpointing
         super().__init__()

     @classmethod
@@ -386,7 +395,15 @@ class OptimusTransformer(nn.Module):

         # go through all model layers
         for layer in self.layers:
-            x = layer(x, attn_mask) # (batch_size, seq_len, dim)
+            if self.config.gradient_checkpointing and self.training:
+                x = checkpoint(
+                    layer.__call__,
+                    x,
+                    attn_mask,
+                    use_reentrant=False,
+                ) # (batch_size, seq_len, dim)
+            else:
+                x = layer(x, attn_mask) # (batch_size, seq_len, dim)

         # final norm + linear layer
         x = self.output_norm(x)
--
GitLab
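
Note (not part of the patch): below is a minimal, self-contained sketch of the checkpointing pattern the patch applies to OptimusTransformer.forward(), for anyone who wants to see the mechanism in isolation. ToyBlock, ToyModel, and their dimensions are made-up stand-ins rather than the Optimus API; only the torch.utils.checkpoint.checkpoint call and the `if self.gradient_checkpointing and self.training` guard mirror the actual change.

```python
# Hypothetical toy model (ToyBlock/ToyModel are illustrative names only),
# showing the same checkpoint(layer.__call__, ..., use_reentrant=False)
# pattern the patch adds to the Optimus transformer layer loop.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return x + self.ff(x)  # residual feed-forward block


class ToyModel(nn.Module):
    def __init__(self, dim=64, n_layers=8, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Keep only the layer input; the layer's internal activations
                # are recomputed during the backward pass.
                x = checkpoint(layer.__call__, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


if __name__ == "__main__":
    model = ToyModel()
    out = model(torch.randn(2, 16, 64))  # (batch_size, seq_len, dim)
    out.sum().backward()                 # triggers recomputation of each checkpointed layer
```

`use_reentrant=False` selects PyTorch's non-reentrant checkpoint implementation, which the PyTorch documentation recommends over the reentrant variant; during backward, each checkpointed layer's forward pass is re-run from its saved input instead of storing the layer's intermediate activations.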