From accab39c8bd0ebd0e136bad320a81b0a3fb48f84 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Thu, 4 Jan 2024 13:52:16 +0200
Subject: [PATCH] Add inference code

Inference example code. At the moment, the code simply loads a model
state file and generates text from it. Parameters such as the maximum
sequence length, whether training used fp16, and which tokenizer was
used for training need to be passed manually by the user (there is a
lot of room for error here). To be improved.
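
A typical invocation (the values below are just the defaults and purely
illustrative) looks something like:

    python inference.py --model_path='model.pth' \
        --tokenizer_path='optimus.model' --max_seq_len=512 \
        --temperature=0.6 --top_p=0.9 --use_fp16=True --device='cuda'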

Merges changes from !14
Closes !14
---
 inference.py | 179 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 179 insertions(+)
 create mode 100644 inference.py

diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..c0939cd
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,179 @@
+import os
+import warnings
+
+import fire
+import torch
+from torch import Tensor
+
+from optimus.tokenizers import SentencePieceTokenizer
+from optimus.models import OptimusTransformer
+
+
+def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
+    """
+    Nucleus (top-p) sampling.
+
+    Args:
+        probs (Tensor): Output probability distribution of the model. This
+            should be the softmax'ed logits.
+        top_p (float): Top-p threshold value for sampling.
+
+    Returns:
+        Tensor: Sampled output token, as index in the vocabulary.
+
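+    Example (illustrative usage, assuming `logits` is a 1-D tensor of
+    vocabulary scores and `temperature` is a float):
+
+        probs = torch.softmax(logits / temperature, dim=-1)
+        next_token = sample_top_p(probs, top_p=0.9) # tensor of shape (1,)
+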
+    """
+    # sort probabilities in descending order and compute their cumulative sum
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    # zero out tokens outside the nucleus (keep the smallest set of tokens
+    # whose cumulative probability exceeds top_p)
+    mask = probs_sum - probs_sort > top_p
+    probs_sort[mask] = 0.
+    # renormalize the remaining probabilities and sample a token from them
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    token = torch.multinomial(probs_sort, num_samples=1)
+    # map the sampled position back to the original vocabulary index
+    token = torch.gather(probs_idx, -1, token)
+    return token
+
+
+def main(model_path: str = 'model.pth',
+         tokenizer_path: str = 'optimus.model',
+         max_seq_len: int = 512,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         use_fp16: bool = True,
+         device: str = 'cuda'):
+    """
+    Run the main inference loop for the model.
+
+    Args:
+        model_path (str): Path to the saved model.
+        tokenizer_path (str): Path to the tokenizer model. Note: Make sure
+            this is the same tokenizer that was used for training, otherwise
+            the generated results will be erroneous.
+        max_seq_len (int): Maximum context length used for generation. Make
+            sure the value does not exceed the 'seq_len' used when training,
+            otherwise the model will generalize poorly to longer contexts.
+        temperature (float): Temperature value to control randomness in
+            sampling.
+        top_p (float): Probability threshold used in nucleus sampling.
+        use_fp16 (bool): Whether computation should run in fp16 or in
+            full-precision fp32. The stored model weights are converted to the
+            requested precision if they do not already match it. If 'use_fp16'
+            is enabled but the model was trained in fp32, the quality of the
+            generated output might drop significantly.
+        device (str): Device on which to run inference. Viable options are
+            'cpu', 'cuda', 'cuda:2' etc.
+
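+    Example (illustrative; full-precision inference on the CPU):
+
+        main(model_path='model.pth', tokenizer_path='optimus.model',
+             use_fp16=False, device='cpu')
+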
+    """
+
+    print(f"Running with:\n"
+        f"\t- model path: '{model_path}'\n"
+        f"\t- tokenizer path: '{tokenizer_path}'\n"
+        f"\t- max context length: {max_seq_len}\n"
+        f"\t- temperature: {temperature}\n"
+        f"\t- top_p threshold: {top_p}\n"
+        f"\t- 16-bit floating-point: {use_fp16}\n"
+        f"\t- running inference on device: {device}\n"
+        f"Please see '--help' if you want to change these settings")
+
+    # load tokenizer
+    tok = SentencePieceTokenizer(model_path=tokenizer_path)
+
+    # set default tensor type for things like positional encodings
+    if use_fp16:
+        if device == 'cuda':
+            # avoid warnings about set_default_tensor_type() being deprecated
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        elif device == 'cpu':
+            raise ValueError("Cannot run 16-bit inference on CPU!")
+    else:
+        if device == 'cuda':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        elif device == 'cpu':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.FloatTensor)
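+    # note: torch.set_default_tensor_type() is deprecated in recent PyTorch
+    # releases; torch.set_default_dtype() (and torch.set_default_device())
+    # would be the non-deprecated way to achieve a similar effect here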
+
+    print("Loading model from disk...")
+
+    # load state from file
+    assert os.path.exists(model_path), f"Model file '{model_path}' does not exist!"
+    state = torch.load(model_path, map_location=device)
+
+    # these will be loaded into the model automatically later anyway, but we
+    # need at least `dim` now to be able to compute positional encodings
+    # correctly, and `vocab_sz` to make sure the tokenizer works
+    vocab_sz = int(state['vocab_sz'])
+    n_layers = int(state['n_layers'])
+    n_heads = int(state['n_heads'])
+    dim = int(state['dim'])
+    p_drop = float(state['p_drop'])
+    weight_tying = bool(state['weight_tying'])
+
+    assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
+        "different from the tokenizer used for training! This will result in "
+        "erroneous generation!")
+
+    # create model, load weights
+    model = OptimusTransformer(vocab_sz=vocab_sz,
+                               n_layers=n_layers,
+                               n_heads=n_heads,
+                               dim=dim,
+                               p_drop=p_drop,
+                               weight_tying=weight_tying)
+    model.load_state_dict(state, strict=True)
+    model.eval()
+
+    print(f"Loaded model on device {device}!")
+
+    _total_params = sum(p.numel() for p in model.parameters())
+    print(f"Number of model parameters: {_total_params}")
+
+    # inference loop
+    print("Starting inference...")
+
+    print("Done! Waiting for user input... (prompt to complete)")
+
+    input_sentence = input("User: ")
+
+    # tokenize input
+    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
+                       dtype=torch.long)
+    inp.unsqueeze_(0) # (1, seq_len)
+
+    seq_len = inp.shape[-1]
+
+    with torch.inference_mode():
+
+        while seq_len < max_seq_len:
+
+            # get model output
+            output = model(inp) # (1, seq_len, vocab_sz)
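+            # (the whole sequence is re-fed through the model at every step;
+            # there is no key/value caching, so each iteration gets slower as
+            # the generated sequence grows)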
+
+            # get the logits for the last token
+            logits = output[0, -1, :] # (vocab_sz,)
+
+            # sample the most relevant token from the probability
+            if temperature > 0:
+                # if temperature is enabled, nucleus (top-p) sampling
+                probs = torch.softmax(logits / temperature, dim=-1)
+                token = sample_top_p(probs, top_p)
+            else:
+                # otherwise, fallback to greedy decoding
+                token = torch.argmax(logits, dim=-1)
+                token.unsqueeze_(0)
+
+            # check if token is EOS
+            if token == tok.eos_id:
+                break
+
+            # append the token to the input
+            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1) # (1, seq_len + 1)
+
+            seq_len = inp.shape[-1]
+
+    print("Model output: ", ' '.join(tok.decode(inp.tolist())))
+
+    print("Finished inference!")
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
-- 
GitLab