Commit 8247f4a4 authored by Alexandru-Mihai GHERGHESCU

Add gradient checkpointing option to Optimus

Gradient (or activation) checkpointing trades compute for memory
savings. Overall, this should make it easier to train large models on
not-so-large hardware.

Add checkpointing to every layer (same as HuggingFace), as opposed to
every 2-3 layers, since 1) it is the easiest option to implement, and
2) it strikes the best balance between memory and compute.
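
As an illustration of the pattern, here is a minimal, self-contained PyTorch sketch of per-layer activation checkpointing. The ToyBlock/ToyModel names and sizes are made up for the example and are not part of Optimus; use_reentrant=False selects PyTorch's non-reentrant checkpoint implementation, the same variant used in the diff below.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    # Stand-in for a transformer decoder layer.
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return x + self.ff(x)

class ToyModel(nn.Module):
    def __init__(self, dim=128, n_layers=8, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(n_layers)])
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Don't keep this layer's intermediate activations; recompute
                # them during the backward pass (extra forward compute).
                x = checkpoint(layer.__call__, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

model = ToyModel().train()
x = torch.randn(2, 16, 128)
model(x).sum().backward()  # each ToyBlock's forward runs again here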
parent 70ccb523
Merge request !25: Re-factor optimus-prime code (optimus-prime v2)
@@ -4,6 +4,7 @@ import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
class OptimusConfig():
@@ -51,6 +52,12 @@ class OptimusConfig():
projection matrices in the self-attention layer.
attention_dropout (float, defaults to 0.0):
Dropout ratio for the attention probabilities.
gradient_checkpointing (bool, defaults to True): Whether to enable
gradient checkpointing (also known as activation checkpointing),
which trades compute for memory. This checkpoints every transformer
decoder layer. The memory savings are in the ballpark of
O(sqrt(n_layers)), while computation increases by ~30% (equivalent
to running the forward pass twice instead of once).
"""
def __init__(
@@ -68,6 +75,7 @@ class OptimusConfig():
attn_implementation="sdpa",
attention_bias=False,
attention_dropout=0.0,
gradient_checkpointing=True,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -82,6 +90,7 @@ class OptimusConfig():
self.attn_implementation = attn_implementation
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.gradient_checkpointing = gradient_checkpointing
super().__init__()
@classmethod
@@ -386,7 +395,15 @@ class OptimusTransformer(nn.Module):
# go through all model layers
for layer in self.layers:
x = layer(x, attn_mask) # (batch_size, seq_len, dim)
if self.config.gradient_checkpointing and self.training:
x = checkpoint(
layer.__call__,
x,
attn_mask,
use_reentrant=False,
) # (batch_size, seq_len, dim)
else:
x = layer(x, attn_mask) # (batch_size, seq_len, dim)
# final norm + linear layer
x = self.output_norm(x)
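
The forward-pass change above only takes effect during training (self.training) and when the config flag is set, so inference is unaffected. A minimal usage sketch, assuming the import path and that OptimusConfig's remaining arguments all have defaults (neither is shown in this diff):

from optimus import OptimusConfig, OptimusTransformer  # import path is an assumption

config = OptimusConfig(gradient_checkpointing=True)  # default in this commit; set False to trade memory for speed
model = OptimusTransformer(config)                   # constructor signature assumed from the diff context

model.train()  # checkpointing active: layer activations are recomputed in backward
model.eval()   # plain per-layer forward path, no recomputation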