diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index 116675bbf3bab2708391abd773166a6818051e33..fc11a93fafa6c8fb6e3a1f00281ca05b2b352983 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -4,6 +4,7 @@
 import json
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 class OptimusConfig():
@@ -51,6 +52,12 @@ class OptimusConfig():
             projection matrices in the self-attention layer.
         attention_dropout (float, defaults to 0.0): Dropout ratio for the
             attention probabilities.
+        gradient_checkpointing (bool, defaults to True): Whether to enable
+            gradient checkpointing (also known as activation checkpointing),
+            which trades compute for memory. This checkpoints every transformer
+            decoder layer. The memory savings are in the ballpark of
+            O(sqrt(n_layers)), while computation increases by ~30% (equivalent
+            to running the forward pass twice instead of once).
     """
 
     def __init__(
@@ -68,6 +75,7 @@ class OptimusConfig():
         attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
+        gradient_checkpointing=True,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -82,6 +90,7 @@ class OptimusConfig():
         self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.gradient_checkpointing = gradient_checkpointing
         super().__init__()
 
     @classmethod
@@ -386,7 +395,15 @@ class OptimusTransformer(nn.Module):
 
         # go through all model layers
         for layer in self.layers:
-            x = layer(x, attn_mask)  # (batch_size, seq_len, dim)
+            if self.config.gradient_checkpointing and self.training:
+                x = checkpoint(
+                    layer.__call__,
+                    x,
+                    attn_mask,
+                    use_reentrant=False,
+                )  # (batch_size, seq_len, dim)
+            else:
+                x = layer(x, attn_mask)  # (batch_size, seq_len, dim)
 
         # final norm + linear layer
         x = self.output_norm(x)
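
For reference, below is a minimal, self-contained sketch of the same per-layer checkpointing pattern the patch adds to OptimusTransformer.forward. The toy module, layer type, and sizes are illustrative only and are not part of this patch; the relevant piece is wrapping each layer call in torch.utils.checkpoint.checkpoint with use_reentrant=False during training.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStack(nn.Module):
    """Illustrative stand-in for a transformer layer stack (not Optimus code)."""

    def __init__(self, dim=64, n_layers=4, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Don't store this layer's activations; recompute them during
                # the backward pass instead (non-reentrant checkpointing).
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


model = ToyStack().train()
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()  # gradients flow through the recomputed layers

Passing layer.__call__ (as the patch does) and passing the module directly (as in this sketch) behave the same way: both route through the module's normal call path, hooks included.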