From a91b0d2ab727ddf7fbb787490a68ed456374723d Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 21:43:12 +0300
Subject: [PATCH] Move Optimus configuration into separate config class

This should be much nicer to work with, since every option / setting of
the model can now be controlled through a single configuration class; the
config can also be created directly from a JSON file.

Set a naming scheme for the Optimus model, similar to HuggingFace
models.
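
Example usage (a sketch; 'config.json' is a hypothetical path whose keys
mirror the OptimusConfig constructor arguments):

    from optimus.models.optimus import OptimusTransformer, OptimusConfig

    # build the config from a JSON file and instantiate the model from it
    config = OptimusConfig.from_json_file('config.json')
    model = OptimusTransformer(config)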
---
 inference.py              |  15 +-
 optimus/models/optimus.py | 359 +++++++++++++++++++++++---------------
 training.py               |  15 +-
 3 files changed, 236 insertions(+), 153 deletions(-)

diff --git a/inference.py b/inference.py
index 45a81ae..0f69c01 100644
--- a/inference.py
+++ b/inference.py
@@ -7,7 +7,7 @@ import torch
 from torch import Tensor
 
 from optimus.tokenizers import SentencePieceTokenizer
-from optimus.models import OptimusTransformer
+from optimus.models.optimus import OptimusTransformer, OptimusConfig
 
 
 def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
@@ -118,12 +118,13 @@ def main(model_path: str = 'model.pth',
         "erroneous generation!")
 
     # create model, load weights
-    model = OptimusTransformer(vocab_sz=vocab_sz,
-                               n_layers=n_layers,
-                               n_heads=n_heads,
-                               dim=dim,
-                               p_drop=p_drop,
-                               weight_tying=weight_tying)
+    config = OptimusConfig(vocab_size=vocab_sz,
+                           num_hidden_layers=n_layers,
+                           num_attention_heads=n_heads,
+                           hidden_size=dim,
+                           attention_dropout=p_drop,
+                           tie_word_embeddings=weight_tying)
+    model = OptimusTransformer(config)
     model.load_state_dict(state, strict=True)
     model.eval()
 
diff --git a/optimus/models/optimus.py b/optimus/models/optimus.py
index 36ce3b0..f053e17 100644
--- a/optimus/models/optimus.py
+++ b/optimus/models/optimus.py
@@ -1,55 +1,170 @@
 import math
+import json
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class Norm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        """
-        Root-mean square layer normalization (https://arxiv.org/abs/1910.07467).
+class OptimusConfig:
+    """
+    Optimus model configuration class to store the configuration of an
+    OptimusTransformer. Left on defaults, it will instantiate an
+    OptimusTransformer similar in size to GPT-2 XL (~1.5B parameters).
+
+    Args:
+        vocab_size (int, defaults to 50257):
+            Vocabulary size of the model. Affects the input embeddings and
+            the output matrix. Make sure to pair the model with a tokenizer
+            of matching vocabulary size.
+        hidden_size (int, defaults to 1600):
+            Dimension of hidden representations.
+        intermediate_size (int, defaults to 4352):
+            Dimension of the MLP (feed forward layer) representations. This
+            should be equal to 8 * hidden_size / 3, rounded up to a multiple of
+            256.
+        num_hidden_layers (int, defaults to 48):
+            Number of hidden layers in the OptimusTransformer decoder.
+        num_attention_heads (int, defaults to 40):
+            Number of attention heads for each attention layer in the
+            OptimusTransformer decoder.
+        hidden_act (str, defaults to "silu"):
+            The non-linear activation function used in the MLP (feed forward)
+            layers of the transformer.
+        max_position_embeddings (int, defaults to 2048):
+            The maximum sequence length that the model might be used for. This
+            is used to pre-compute the sinusoidal position embeddings.
+        initializer_range (float, defaults to 0.02):
+            Standard deviation for initializing weight matrices.
+        rms_norm_eps (float, defaults to 1e-6):
+            Epsilon value used by the RMS normalization layers.
+        tie_word_embeddings (bool, defaults to False):
+            Whether to tie word embeddings with the output matrix. Note: This
+            has limited support with PyTorch FSDP!
+        attention_bias (bool, defaults to False):
+            Whether to use a bias term in the query, key, value and output
+            projection matrices in the self-attention layer.
+        attention_dropout (float, defaults to 0.0):
+            Dropout ratio for the attention probabilities.
+
+    """
+    def __init__(
+        self,
+        vocab_size=50257,
+        hidden_size=1600,
+        intermediate_size=4352,
+        num_hidden_layers=48,
+        num_attention_heads=40,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        tie_word_embeddings=False,
+        attention_bias=False,
+        attention_dropout=0.0,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.tie_word_embeddings = tie_word_embeddings
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        super().__init__()
 
-        Args:
-            See Transformer.
+    @classmethod
+    def from_json_file(cls, file_path):
+        with open(file_path, 'r') as file:
+            return cls(**json.load(file))
 
-        """
+
+class OptimusRMSNorm(nn.Module):
+    """
+    Root-mean square layer normalization (https://arxiv.org/abs/1910.07467).
+
+    Args:
+        See OptimusTransformer.
+
+    """
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
+        self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def _norm(self, x):
-        std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
-        return x / (std + self.eps)
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        return x * torch.rsqrt(variance + self.eps)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # compute in float32, not in fp16, since normalization needs to be accurate
         output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return self.weight * output
 
 
-class Attention(nn.Module):
-    def __init__(self, n_heads: int, dim: int, p_drop: float):
-        """
-        Multi-head attention layer.
+class OptimusFeedForward(nn.Module):
+    """
+    Feed forward (MLP) layer, similar to Llama 2's MLP implementation.
 
-        Args:
-            See Transformer.
+    This is a modified version of the original feed forward layer from
+    "Attention Is All You Need". The changes are described in "GLU variants
+    improve transformer" (https://arxiv.org/abs/2002.05202).
 
-        """
+    A summary of the changes proposed:
+    - 3 linear layers instead of 2
+    - reduce the inner dimension of the 3 matrices to 2/3 * (4 * hidden_size)
+      to keep computation constant (set here through intermediate_size)
+    - change the activation function to SwiGLU (Swish)
+
+    Args:
+        See OptimusTransformer.
+
+    """
+    def __init__(self, config: OptimusConfig):
         super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
 
-        assert dim % n_heads == 0
-        self.head_dim = dim // n_heads
-        self.heads = n_heads
-        self.scale = 1 / math.sqrt(self.head_dim)
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+        if config.hidden_act in ('silu', 'swish'):
+            self.act_fn = F.silu
+        else:
+            raise ValueError("Currently only 'silu' and 'swish' are supported as "
+                             f"activation functions, but got '{config.hidden_act}'")
 
-        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+    def forward(self, x):
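+        # SwiGLU: the gate path goes through SiLU and gates the up projection,
+        # then the result is projected back down to hidden_size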
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class OptimusAttention(nn.Module):
+    """
+    Multi-head attention layer.
+
+    Args:
+        See OptimusTransformer.
+
+    """
+    def __init__(self, config: OptimusConfig, layer_idx: int):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+
+        assert self.hidden_size % self.num_heads == 0
+        self.scale = 1 / math.sqrt(self.head_dim)
 
-        self.dropout = nn.Dropout(p=p_drop)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
     def forward(self,
                 x: torch.Tensor,
@@ -58,14 +173,14 @@ class Attention(nn.Module):
         bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
 
         # linearly project the inputs
-        Q = self.wq(x) # (batch_size, seq_len, n_heads * head_dim)
-        K = self.wk(x)
-        V = self.wv(x)
+        Q = self.q_proj(x) # (batch_size, seq_len, n_heads * head_dim)
+        K = self.k_proj(x)
+        V = self.v_proj(x)
 
-        # split into n_heads to compute attention
-        queries = Q.view(bs, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, head_dim)
-        keys = K.view(bs, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
-        values = V.view(bs, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
+        # split into num_heads to compute attention
+        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, head_dim)
+        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        values = V.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 
         # compute attention matmul
         keys = keys.permute(0, 1, 3, 2) # (batch_size, n_heads, head_dim, seq_len)
@@ -78,133 +193,95 @@ class Attention(nn.Module):
         if mask is not None:
             scores = scores + mask
 
-        # softmax (attention probabilities) + dropout
-        attn_probs = self.dropout(F.softmax(scores, dim=-1))
+        # softmax (attention probabilities) in float32 + dropout
+        attn_probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(dtype=Q.dtype)
+        attn_probs = F.dropout(attn_probs, p=self.attention_dropout, training=self.training)
 
         # matmul
         output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim)
 
         output = output.permute(0, 2, 1, 3).contiguous()
-        output = output.view(bs, seq_len, self.heads * self.head_dim) # (batch_size, seq_len, n_heads * head_dim)
+        output = output.view(bs, seq_len, self.num_heads * self.head_dim) # (batch_size, seq_len, num_heads * head_dim)
 
         # final linear
-        output = self.wo(output)
+        output = self.o_proj(output)
         return output
 
 
-class FeedForward(nn.Module):
-    def __init__(self, dim: int, ffn_dim: int):
-        """
-        Feed forward layer.
-
-        This is a modified version of the original feed forward layer from
-        "Attention Is All You Need". The changes are described in "GLU variants
-        improve transformer" (https://arxiv.org/abs/2002.05202).
-
-        A summary of the changes proposed:
-        - 3 linear layers instead of 2
-        - reduce dimension of the 3 matrices to 2/3 * dim to keep computation
-          constant
-        - change the activation function to SwiGLU (Swish)
-
-        Args:
-            See Transformer.
-
-        """
-        super().__init__()
-
-        ffn_dim = int(2 * ffn_dim / 3)
-
-        self.fn1 = nn.Linear(dim, ffn_dim, bias=False)
-        self.fn2 = nn.Linear(ffn_dim, dim, bias=False)
-        self.fn3 = nn.Linear(dim, ffn_dim, bias=False)
-
-    def forward(self, x):
-        return self.fn2(F.silu(self.fn1(x)) * self.fn3(x))
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, n_heads: int, dim: int, p_drop: float):
-        """
-        A transformer layer, contains a multi-head attention and a feed forward.
+class OptimusDecoderLayer(nn.Module):
+    """
+    OptimusTransformer decoder layer; contains a multi-head attention and a feed
+    forward (MLP) layer.
 
-        Args:
-            See Transformer.
+    Args:
+        See OptimusTransformer.
 
-        """
+    """
+    def __init__(self, config: OptimusConfig, layer_idx: int):
         super().__init__()
-        self.attention = Attention(n_heads, dim, p_drop)
-        self.ffn = FeedForward(dim, 4 * dim)
-        self.attention_norm = Norm(dim)
-        self.ffn_norm = Norm(dim)
-        self.dropout = nn.Dropout(p_drop)
+        self.self_attn = OptimusAttention(config, layer_idx)
+        self.ffn = OptimusFeedForward(config)
+        self.attn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.ffn_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor,
+                attn_mask: torch.Tensor,
+                ) -> torch.Tensor:
         """
         Compute a layer of the model (attention + ffn).
         """
-        _, seq_len, _ = x.shape # (bs, seq_len, dim)
-
-        # compute mask for masked self-attention
-        mask = torch.full((1, seq_len, seq_len), float("-inf"))
-        mask = torch.triu(mask, diagonal=1).to(x.device)
-
         # compute normed attention and normed feed forward
-        x = x + self.dropout(self.attention(self.attention_norm(x), mask))
-        x = x + self.dropout(self.ffn(self.ffn_norm(x)))
+        x = x + self.self_attn(self.attn_norm(x), attn_mask)
+        x = x + self.ffn(self.ffn_norm(x))
         return x
 
 
 class OptimusTransformer(nn.Module):
-    def __init__(self,
-                 vocab_sz: int,
-                 dim: int = 512,
-                 n_layers: int = 6,
-                 n_heads: int = 8,
-                 p_drop: float = 0.0,
-                 weight_tying: bool = False):
-        """
-        A transformer implementation. Most of it is derived directly from
-        "Attention is All You Need" (see https://arxiv.org/abs/1706.03762),
-        with additional changes from recent research.
+    """
+    A plain PyTorch decoder-only GPT-style transformer implementation. Most of
+    this model directly copies Llama 2's implementation, with small changes and
+    adjustments for experimenting and playing around.
+
+    Args:
+        config (OptimusConfig): An OptimusTransformer configuration used to
+            instantiate the model. With defaults, it instantiates a model
+            similar to GPT-2 XL in size (~1.5B parameters).
+
+    """
+    def __init__(self, config: OptimusConfig):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
 
-        Args:
-            vocab_sz (int): Vocabulary size.
-            dim (int): The dimension of embeddings in the model.
-            n_layers (int): The number of layers in the model.
-            n_heads (int): The number of attention heads.
-            p_drop (float): Dropout probability. Dropout is applied to input
-                embeddings, the outputs of attention layers, and the outputs of
-                feed forward layers.
-            weight_tying (bool): Whether to use weight-tying. Intuitively,
-                weight-tying could be useful, because the set of mappings from
-                the input tokens to the embeddings should be the same as that
-                from the embeddings to the output tokens (because the vocabulary
-                is the same for both encoding and decoding).
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [OptimusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
 
-        """
-        super().__init__()
+        self.positional_encodings = self._compute_freqs(config.hidden_size)
+
+        self.output = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
+        self.output_norm = OptimusRMSNorm(config.hidden_size, config.rms_norm_eps)
 
-        # add as buffers so they get saved along with the model with torch.save()
-        self.register_buffer('vocab_sz', torch.tensor(vocab_sz))
-        self.register_buffer('n_layers', torch.tensor(n_layers))
-        self.register_buffer('n_heads', torch.tensor(n_heads))
-        self.register_buffer('dim', torch.tensor(dim))
-        self.register_buffer('p_drop', torch.tensor(p_drop))
-        self.register_buffer('weight_tying', torch.tensor(weight_tying))
-
-        self.embeddings = nn.Embedding(vocab_sz, dim)
-        self.positional_encodings = self._compute_freqs(dim)
-        self.input_dropout = nn.Dropout(p=p_drop)
-        self.layers = nn.ModuleList([TransformerBlock(n_heads, dim, p_drop) for _ in range(n_layers)])
-        self.output = nn.Linear(dim, vocab_sz, bias=False)
-        self.output_norm = Norm(dim)
-
-        if weight_tying:
+        # ! careful, this has limited FSDP support
+        if config.tie_word_embeddings:
             self.output.weight = self.embeddings.weight
 
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
     def _compute_freqs(self, dim: int,
-                       max_seq_len: int = 4096,
                        theta: float = 10000.0) -> torch.Tensor:
         """
         Precompute a frequency matrix to apply to input embeddings (positional
@@ -212,15 +289,13 @@ class OptimusTransformer(nn.Module):
 
         Args:
             dim (int): Dimension of embeddings in the model.
-            max_seq_len (int): The maximum context length to precompute
-                frequencies for. Defaults to 4096. Needs to be increased for
-                bigger context lengths.
             theta (float): Frequency.
 
         Returns:
             torch.Tensor: Positional encodings tensor.
 
         """
+        max_seq_len = self.config.max_position_embeddings
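+        # standard sinusoidal encodings ("Attention Is All You Need"):
+        # frequencies theta^(-2k/dim) for k = 0 .. dim/2 - 1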
         pos = torch.arange(0, max_seq_len).unsqueeze(1)
         i = torch.arange(0, dim, 2)
         div_term = torch.exp(i * (-math.log(theta) / dim))
@@ -244,14 +319,20 @@ class OptimusTransformer(nn.Module):
             torch.Tensor: Output logits of the model.
 
         """
+        _, seq_len = x.shape # (bs, seq_len)
+
         # input embeddings and positional encoding
         x = self.embeddings(x) # (batch_size, seq_len, dim)
-        encs = self.positional_encodings[:,: x.shape[1]].to(x.device)
-        x = self.input_dropout(x + encs)
+        encs = self.positional_encodings[:,:seq_len].to(x.device)
+        x = x + encs
+
+        # compute mask for masked self-attention
+        attn_mask = torch.full((1, seq_len, seq_len), torch.finfo(x.dtype).min)
+        attn_mask = torch.triu(attn_mask, diagonal=1).to(dtype=x.dtype, device=x.device)
 
         # go through all model layers
         for layer in self.layers:
-            x = layer(x) # (batch_size, seq_len, dim)
+            x = layer(x, attn_mask) # (batch_size, seq_len, dim)
 
         # final norm + linear layer
         x = self.output_norm(x)
diff --git a/training.py b/training.py
index 64cdf8f..897a285 100644
--- a/training.py
+++ b/training.py
@@ -5,7 +5,7 @@ from torch import nn
 from optimus.datasets import WikiText103Dataset
 from optimus.tokenizers import SentencePieceTokenizer
 from optimus.dataloader import OptimusDataLoader
-from optimus.models import OptimusTransformer
+from optimus.models.optimus import OptimusTransformer, OptimusConfig
 from optimus.trainer import Trainer
 
 
@@ -78,12 +78,13 @@ def main(batch_size: int = 8,
                            device='cuda')
 
     # create model and move to device
-    model = OptimusTransformer(len(tok),
-                               n_layers=n_layers,
-                               dim=dim,
-                               n_heads=n_heads,
-                               p_drop=dropout,
-                               weight_tying=False)
+    config = OptimusConfig(vocab_size=len(tok),
+                           num_hidden_layers=n_layers,
+                           num_attention_heads=n_heads,
+                           hidden_size=dim,
+                           attention_dropout=dropout,
+                           tie_word_embeddings=False)
+    model = OptimusTransformer(config)
     model = model.to('cuda')
 
     _total_params = sum(p.numel() for p in model.parameters())
-- 
GitLab