From 70ccb52323df5d4bb6d07d5d23dd56429d4dec4a Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 22:05:19 +0300
Subject: [PATCH] Add PyTorch built-in SDPA to Optimus

Add PyTorch's core scaled dot-product attention (SDPA) to Optimus. This
automatically uses flash attention 2 or memory-efficient attention if
the hardware supports it; if it doesn't, it falls back to the manual
implementation.

Training should be much faster with this; memory usage should also be
around half of what it was before.
---
 optimus/models/optimus.py | 56 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index f053e17..116675b 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -41,6 +41,11 @@ class OptimusConfig():
         tie_word_embeddings (bool, defaults to False):
             Whether to tie word embeddings with the output matrix. Note: This
             has limited support with PyTorch FSDP!
+        attn_implementation (str, defaults to "sdpa"):
+            Attention implementation. Can be "eager" for the original
+            implementation, or "sdpa" for PyTorch's built-in Scaled Dot-Product
+            Attention, which implements eager mode, flash attention 2 and
+            memory-efficient attention in core.
         attention_bias (bool, defaults to False):
             Whether to use a bias term in the query, key, value and output
             projection matrices in the self-attention layer.
@@ -60,6 +65,7 @@ class OptimusConfig():
                  initializer_range=0.02,
                  rms_norm_eps=1e-6,
                  tie_word_embeddings=False,
+                 attn_implementation="sdpa",
                  attention_bias=False,
                  attention_dropout=0.0,
                  ):
@@ -73,6 +79,7 @@ class OptimusConfig():
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.tie_word_embeddings = tie_word_embeddings
+        self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         super().__init__()
@@ -208,6 +215,53 @@ class OptimusAttention(nn.Module):
 
         return output
 
+class OptimusSdpaAttention(OptimusAttention):
+    """
+    PyTorch native scaled dot-product attention. If GPU supports it, runs
+    Flash Attention 2. Inherits from normal 'OptimusAttention' as the weights
+    remain untouched. Only changes the forward pass.
+
+    """
+    def forward(self,
+                x: torch.Tensor,
+                mask: torch.Tensor | None = None) -> torch.Tensor:
+
+        bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
+
+        # linearly project the inputs
+        Q = self.q_proj(x) # (batch_size, seq_len, n_heads * head_dim)
+        K = self.k_proj(x)
+        V = self.v_proj(x)
+
+        # split into num_heads to compute attention
+        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, head_dim)
+        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        values = V.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        # compute attention through flash attention 2, if supported by the GPU
+        attn_output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            dropout_p=self.attention_dropout,
+            attn_mask=None,
+            is_causal=mask is not None,
+        )
+
+        output = attn_output.permute(0, 2, 1, 3).contiguous()
+        output = output.view(bs, seq_len, self.num_heads * self.head_dim) # (batch_size, seq_len, num_heads * head_dim)
+
+        # final linear
+        output = self.o_proj(output)
+        return output
+
+
+OPTIMUS_ATTENTION_CLASSES = {
+    "eager": OptimusAttention,
+    "sdpa": OptimusSdpaAttention,
+}
+
+
 class OptimusDecoderLayer(nn.Module):
     """
     OptimusTransformer decoder layer; contains a multi-head attention and a feed
@@ -219,7 +273,7 @@ class OptimusDecoderLayer(nn.Module):
     """
     def __init__(self, config: OptimusConfig, layer_idx: int):
         super().__init__()
-        self.self_attn = OptimusAttention(config, layer_idx)
+        self.self_attn = OPTIMUS_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
        self.ffn = OptimusFeedForward(config)
         self.attn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
         self.ffn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
-- 
GitLab
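
For reference, a minimal usage sketch of the new config option (not part of the patch; it assumes the classes above are importable from optimus.models.optimus, per the diff header, and that the remaining OptimusConfig parameters have usable defaults):

    # import path and default OptimusConfig arguments are assumptions
    from optimus.models.optimus import (
        OPTIMUS_ATTENTION_CLASSES,
        OptimusConfig,
        OptimusDecoderLayer,
    )

    # "sdpa" selects OptimusSdpaAttention; "eager" keeps the original
    # OptimusAttention implementation.
    config = OptimusConfig(attn_implementation="sdpa")

    # OptimusDecoderLayer resolves the attention class through the lookup table.
    layer = OptimusDecoderLayer(config, layer_idx=0)
    assert isinstance(layer.self_attn, OPTIMUS_ATTENTION_CLASSES["sdpa"])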