diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index f053e17e48000b2b07aa86bb5cac660b53e45e15..116675bbf3bab2708391abd773166a6818051e33 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -41,6 +41,11 @@ class OptimusConfig():
         tie_word_embeddings (bool, defaults to False):
             Whether to tie word embeddings with the output matrix. Note: This
             has limited support with PyTorch FSDP!
+        attn_implementation (str, defaults to "sdpa"):
+            Attention implementation to use. Can be "eager" for the original
+            implementation, or "sdpa" for PyTorch's built-in scaled dot-product
+            attention, which dispatches to Flash Attention 2, memory-efficient
+            attention or a math (eager-style) kernel, depending on support.
         attention_bias (bool, defaults to False):
             Whether to use a bias term in the query, key, value and output
             projection matrices in the self-attention layer.
@@ -60,6 +65,7 @@ class OptimusConfig():
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         tie_word_embeddings=False,
+        attn_implementation="sdpa",
         attention_bias=False,
         attention_dropout=0.0,
     ):
@@ -73,6 +79,7 @@ class OptimusConfig():
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.tie_word_embeddings = tie_word_embeddings
+        self.attn_implementation = attn_implementation
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         super().__init__()
@@ -208,6 +215,53 @@ class OptimusAttention(nn.Module):
         return output
 
 
+class OptimusSdpaAttention(OptimusAttention):
+    """
+    PyTorch native scaled dot-product attention. If GPU supports it, runs
+    Flash Attention 2. Inherits from normal 'OptimusAttention' as the weights
+    remain untouched. Only changes the forward pass.
+
+    """
+    def forward(self,
+                x: torch.Tensor,
+                mask: torch.Tensor | None = None) -> torch.Tensor:
+
+        bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
+
+        # linearly project the inputs
+        Q = self.q_proj(x) # (batch_size, seq_len, n_heads * head_dim)
+        K = self.k_proj(x)
+        V = self.v_proj(x)
+
+        # split into num_heads to compute attention
+        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, head_dim)
+        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        values = V.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        # run SDPA; PyTorch dispatches to Flash Attention 2, memory-efficient or math kernels
+        attn_output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            dropout_p=self.attention_dropout if self.training else 0.0,  # no dropout at eval time
+            attn_mask=None,  # any provided mask is assumed to be the standard causal mask
+            is_causal=mask is not None,
+        )
+
+        output = attn_output.permute(0, 2, 1, 3).contiguous()
+        output = output.view(bs, seq_len, self.num_heads * self.head_dim) # (batch_size, seq_len, num_heads * head_dim)
+
+        # final linear
+        output = self.o_proj(output)
+        return output
+
+
+OPTIMUS_ATTENTION_CLASSES = {
+    "eager": OptimusAttention,
+    "sdpa": OptimusSdpaAttention,
+}
+
+
 class OptimusDecoderLayer(nn.Module):
     """
     OptimusTransformer decoder layer; contains a multi-head attention and a feed
@@ -219,7 +273,7 @@ class OptimusDecoderLayer(nn.Module):
     """
     def __init__(self, config: OptimusConfig, layer_idx: int):
         super().__init__()
-        self.self_attn = OptimusAttention(config, layer_idx)
+        self.self_attn = OPTIMUS_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
         self.ffn = OptimusFeedForward(config)
         self.attn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
         self.ffn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
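
Not part of the patch, but a quick standalone sanity check of the equivalence this swap relies on: a manual "eager" causal attention compared against F.scaled_dot_product_attention with is_causal=True. It assumes only a recent PyTorch install; the names here are illustrative and not taken from the repo.

import math
import torch
import torch.nn.functional as F

bs, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
v = torch.randn(bs, n_heads, seq_len, head_dim)

# eager reference: scaled scores, causal mask, softmax, weighted sum of values
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
eager_out = torch.softmax(scores.masked_fill(causal, float("-inf")), dim=-1) @ v

# SDPA builds the causal mask internally and picks a backend automatically
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(sdpa_out, eager_out, rtol=1e-4, atol=1e-4)

Selecting the implementation then goes through the config. A hypothetical usage sketch, assuming the remaining OptimusConfig defaults are enough to construct a layer on their own:

from optimus.models.optimus import OptimusConfig, OptimusDecoderLayer

config = OptimusConfig(attn_implementation="sdpa")  # or "eager" for the original path
layer = OptimusDecoderLayer(config, layer_idx=0)
print(type(layer.self_attn).__name__)  # expected: "OptimusSdpaAttention"

On the dropout argument: the functional SDPA applies dropout_p unconditionally, so gating it on self.training keeps evaluation deterministic even when attention_dropout > 0.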