diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index f053e17e48000b2b07aa86bb5cac660b53e45e15..116675bbf3bab2708391abd773166a6818051e33 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -41,6 +41,11 @@ class OptimusConfig():
         tie_word_embeddings (bool, defaults to False):
             Whether to tie word embeddings with the output matrix. Note: This
             has limited support with PyTorch FSDP!
+        attn_implementation (str, defaults to "sdpa"):
+            Attention implementation. Can be "eager" for the original
+            implementation, or "sdpa" for PyTorch's built-in scaled dot-product
+            attention, which provides Flash Attention 2, memory-efficient
+            attention and a math fallback in core.
         attention_bias (bool, defaults to False):
             Whether to use a bias term in the query, key, value and output
             projection matrices in the self-attention layer.
@@ -60,6 +65,7 @@ class OptimusConfig():
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         tie_word_embeddings=False,
+        attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
     ):
@@ -73,6 +79,7 @@ class OptimusConfig():
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.tie_word_embeddings = tie_word_embeddings
+        self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         super().__init__()
@@ -208,6 +215,53 @@ class OptimusAttention(nn.Module):
         return output
 
 
+class OptimusSdpaAttention(OptimusAttention):
+    """
+    PyTorch-native scaled dot-product attention. Runs Flash Attention 2 if
+    the GPU supports it. Inherits from 'OptimusAttention' since the weights
+    are unchanged; only the forward pass differs.
+
+    """
+    def forward(self,
+                x: torch.Tensor,
+                mask: torch.Tensor | None = None) -> torch.Tensor:
+
+        bs, seq_len, _ = x.shape  # (batch_size, seq_len, dim)
+
+        # linearly project the inputs
+        Q = self.q_proj(x)  # (batch_size, seq_len, n_heads * head_dim)
+        K = self.k_proj(x)
+        V = self.v_proj(x)
+
+        # split into num_heads to compute attention
+        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, head_dim)
+        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        values = V.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        # compute attention through flash attention 2, if supported; SDPA builds the causal mask internally
+        attn_output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            attn_mask=None,
+            is_causal=mask is not None,
+        )
+
+        output = attn_output.permute(0, 2, 1, 3).contiguous()
+        output = output.view(bs, seq_len, self.num_heads * self.head_dim)  # (batch_size, seq_len, num_heads * head_dim)
+
+        # final linear projection
+        output = self.o_proj(output)
+        return output
+
+
+OPTIMUS_ATTENTION_CLASSES = {
+    "eager": OptimusAttention,
+    "sdpa": OptimusSdpaAttention,
+}
+
+
 class OptimusDecoderLayer(nn.Module):
     """
     OptimusTransformer decoder layer; contains a multi-head attention and a feed
@@ -219,7 +273,7 @@ class OptimusDecoderLayer(nn.Module):
     """
     def __init__(self, config: OptimusConfig, layer_idx: int):
         super().__init__()
-        self.self_attn = OptimusAttention(config, layer_idx)
+        self.self_attn = OPTIMUS_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
         self.ffn = OptimusFeedForward(config)
         self.attn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
         self.ffn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
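
Usage note: a minimal sketch of how the new config switch is expected to be selected, not part of the diff. It assumes OptimusConfig and OptimusDecoderLayer are importable from optimus.models.optimus and that the constructor arguments not shown in the diff all have defaults; only attn_implementation is taken from the change above.

    from optimus.models.optimus import OptimusConfig, OptimusDecoderLayer

    # "eager" keeps the original attention implementation
    config = OptimusConfig(attn_implementation="eager")
    layer = OptimusDecoderLayer(config, layer_idx=0)
    print(type(layer.self_attn).__name__)  # -> OptimusAttention

    # "sdpa" (the default) picks OptimusSdpaAttention from OPTIMUS_ATTENTION_CLASSES
    config = OptimusConfig(attn_implementation="sdpa")
    layer = OptimusDecoderLayer(config, layer_idx=0)
    print(type(layer.self_attn).__name__)  # -> OptimusSdpaAttention

Because OptimusSdpaAttention only overrides forward(), checkpoints trained with one implementation load under the other without any weight conversion.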