diff --git a/optimus/models/__init__.py b/optimus/models/__init__.py
index ce35655a73bbb58ebef9b2c672180e023763e737..d38f83611e34610eacba0f3de7b07344e9e7ed73 100644
--- a/optimus/models/__init__.py
+++ b/optimus/models/__init__.py
@@ -1 +1,2 @@
 from .optimus import OptimusTransformer
+from .pytorch_transformer import PyTorchTransformer
diff --git a/optimus/models/pytorch_transformer.py b/optimus/models/pytorch_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccc81b4bde96e70b585a145a0ad560e8543cd954
--- /dev/null
+++ b/optimus/models/pytorch_transformer.py
@@ -0,0 +1,137 @@
+import math
+
+import torch.nn as nn
+import torch
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class Norm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Root-mean-square layer normalization (https://arxiv.org/abs/1910.07467).
+
+        Args:
+            See PyTorchTransformer.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
+        return x / (rms + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # compute in float32, not in fp16, since normalization needs to be accurate
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class PyTorchTransformer(nn.Module):
+
+    def __init__(self,
+                 vocab_sz: int,
+                 dim: int = 512,
+                 n_layers: int = 6,
+                 n_heads: int = 8,
+                 p_drop: float = 0.0,
+                 weight_tying: bool = False):
+        """
+        Decoder-only transformer built on PyTorch's TransformerEncoder with a causal mask.
+
+        Args:
+            vocab_sz (int): Vocabulary size.
+            dim (int): The dimension of embeddings in the model.
+            n_layers (int): The number of layers in the model.
+            n_heads (int): The number of attention heads.
+            p_drop (float): Dropout probability. Dropout is applied to input
+                embeddings, the outputs of attention layers, and the outputs of
+                feed forward layers.
+            weight_tying (bool): Whether to use weight-tying. Intuitively,
+                weight-tying could be useful, because the set of mappings from
+                the input tokens to the embeddings should be the same as that
+                from the embeddings to the output tokens (because the vocabulary
+                is the same for both encoding and decoding).
+
+        """
+        super().__init__()
+
+        # add as buffers so they get saved along with the model with torch.save()
+        self.register_buffer('vocab_sz', torch.tensor(vocab_sz))
+        self.register_buffer('n_layers', torch.tensor(n_layers))
+        self.register_buffer('n_heads', torch.tensor(n_heads))
+        self.register_buffer('dim', torch.tensor(dim))
+        self.register_buffer('p_drop', torch.tensor(p_drop))
+        self.register_buffer('weight_tying', torch.tensor(weight_tying))
+
+        self.embeddings = nn.Embedding(vocab_sz, dim)
+        self.input_dropout = nn.Dropout(p=p_drop)
+        self.positional_encodings = self._compute_freqs(dim)
+
+        decoder_layers = TransformerEncoderLayer(dim, n_heads, 4 * dim, dropout=p_drop, batch_first=True, norm_first=True, bias=True)
+        self.transformer_decoder = TransformerEncoder(decoder_layers, n_layers)
+
+        self.output = nn.Linear(dim, vocab_sz, bias=False)
+        self.output_norm = Norm(dim)
+
+        if weight_tying:
+            self.output.weight = self.embeddings.weight
+
+    def _compute_freqs(self, dim: int,
+                       max_seq_len: int = 4096,
+                       theta: float = 10000.0) -> torch.Tensor:
+        """
+        Precompute a frequency matrix to apply to input embeddings (positional
+        encoding).
+
+        Args:
+            dim (int): Dimension of embeddings in the model.
+            max_seq_len (int): The maximum context length to precompute
+                frequencies for. Defaults to 4096. Needs to be increased for
+                bigger context lengths.
+            theta (float): Base for the sinusoidal frequencies.
+
+        Returns:
+            torch.Tensor: Positional encodings tensor.
+
+        """
+        pos = torch.arange(0, max_seq_len).unsqueeze(1)
+        i = torch.arange(0, dim, 2)
+        div_term = torch.exp(i * (-math.log(theta) / dim))
+
+        enc = torch.zeros(max_seq_len, dim)
+        enc[:, 0::2] = torch.sin(pos * div_term)
+        enc[:, 1::2] = torch.cos(pos * div_term)
+        enc = enc.unsqueeze(0)
+
+        return enc
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the model to the input tokens.
+
+        Args:
+            x (torch.Tensor): Input token indices. Expected in the shape
+                (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: Output logits of the model.
+
+        """
+        # input embeddings and positional encoding
+        x = self.embeddings(x)  # (batch_size, seq_len, dim)
+        encs = self.positional_encodings[:, :x.shape[1]].to(x.device)
+        x = self.input_dropout(x + encs)
+
+        _, seq_len, _ = x.shape  # (bs, seq_len, dim)
+
+        # compute mask for masked self-attention
+        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
+
+        output = self.transformer_decoder(x, mask, is_causal=True)
+
+        output = self.output_norm(output)
+        output = self.output(output)
+
+        return output
\ No newline at end of file
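
Below is a minimal usage sketch for the module added in this diff; the hyperparameters, vocabulary size, and token batch are illustrative placeholders, not values taken from this PR:

    import torch
    from optimus.models import PyTorchTransformer

    # illustrative hyperparameters (not prescribed by this PR)
    model = PyTorchTransformer(vocab_sz=32000, dim=512, n_layers=6,
                               n_heads=8, p_drop=0.1, weight_tying=True)

    # fake batch of token indices with shape (batch_size=2, seq_len=16)
    tokens = torch.randint(0, 32000, (2, 16))

    # forward pass returns next-token logits: (2, 16, 32000)
    logits = model(tokens)

With weight_tying=True, the output projection reuses the embedding matrix, so the head adds no extra parameters; the shapes line up because both map between the same vocabulary and embedding dimension.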