From cb1a7974c1cbff0433662ad52f2fce5077f30c3d Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 10 Jun 2024 16:05:34 +0300
Subject: [PATCH] Fix inference code

The inference code should now work with any PyTorch model (Optimus is
the example used in the source code), as well as with any HuggingFace
model; the code was adjusted to be independent of the model source.
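
For illustration, the new arguments can be exercised directly from
Python (the file names below are only examples, not files shipped with
the repository):

    from pathlib import Path
    from inference import main

    main(model_config_path=Path('config.json'),
         weights_path=Path('weights.pth'),
         hf_tokenizer_name='gpt2',
         prompt='Once upon a time',
         device='cuda')

The same arguments can be passed as command-line flags, assuming the
script is still driven through fire.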
---
 inference.py | 76 +++++++++++++++++++++++++---------------------------
 1 file changed, 36 insertions(+), 40 deletions(-)

diff --git a/inference.py b/inference.py
index 99426b7..cfa5823 100644
--- a/inference.py
+++ b/inference.py
@@ -1,7 +1,7 @@
-import os
 import time
 import warnings
 import logging
+from pathlib import Path
 
 import fire
 import torch
@@ -37,19 +37,28 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
     return token
 
 
-def main(model_path: str = 'model.pth',
-         hf_tokenizer_name: str = 'gpt2',
-         prompt: str | None = None,
-         max_seq_len: int = 512,
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         use_fp16: bool = True,
-         device: str = 'cuda'):
+def main(
+    # model + tokenizer
+    model_config_path: Path = Path('config.json'),
+    weights_path: Path = Path('weights.pth'),
+    hf_tokenizer_name: str = 'gpt2',
+
+    # prompt
+    prompt: str | None = None,
+
+    # inference args
+    max_seq_len: int = 512,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    use_fp16: bool = True,
+    device: str = 'cuda'
+):
     """
     Run the main inference loop for the model.
 
     Args:
-        model_path (str): Path to the saved model.
+        model_config_path (Path): Path to the model's JSON config file.
+        weights_path (Path): Path to the trained model weights file.
         hf_tokenizer_name (str): HuggingFace tokenizer. Note: Ensure the same
             tokenizer is used for both training and testing, otherwise the
             results might be undefined.
@@ -72,7 +81,8 @@ def main(model_path: str = 'model.pth',
     """
 
     logger.info(f"Running with:\n"
-        f"\t- model path: '{model_path}'\n"
+        f"\t- model config path: '{model_config_path}'\n"
+        f"\t- weights file: '{weights_path}'\n"
         f"\t- huggingface tokenizer: '{hf_tokenizer_name}'\n"
         f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
         f"\t- max context length: {max_seq_len}\n"
@@ -103,34 +113,19 @@ def main(model_path: str = 'model.pth',
 
     logger.info('Loading model from disk...')
 
-    # load state from file
-    assert os.path.exists(model_path)
-    state = torch.load(model_path, map_location=device)
-
-    # these will be loaded into the model automatically later anyway, but we
-    # need at least `dim` now to be able to compute positional encodings
-    # correctly, and `vocab_sz` to make sure the tokenizer works
-    vocab_sz = int(state['vocab_sz'])
-    n_layers = int(state['n_layers'])
-    n_heads = int(state['n_heads'])
-    dim = int(state['dim'])
-    p_drop = float(state['p_drop'])
-    weight_tying = bool(state['weight_tying'])
-
-    assert vocab_sz == len(tokenizer), ('The tokenizer passed for inference is '
-        'different from the tokenizer used for training! This will result in '
-        'erroneous generation!')
-
-    # create model, load weights
-    config = OptimusConfig(vocab_size=vocab_sz,
-                           num_hidden_layers=n_layers,
-                           num_attention_heads=n_heads,
-                           hidden_size=dim,
-                           attention_dropout=p_drop,
-                           tie_word_embeddings=weight_tying)
+    # create model and load weights
+    config = OptimusConfig.from_json_file(model_config_path)
     model = OptimusTransformer(config)
-    model.load_state_dict(state, strict=True)
+    model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
+
+    # alternatively, comment out the 3 lines above and uncomment / adapt the
+    # following 3 lines to use a HuggingFace model instead
+    # from transformers import AutoTokenizer, AutoModelForCausalLM
+    # tokenizer = AutoTokenizer.from_pretrained('gpt2')
+    # model = AutoModelForCausalLM.from_pretrained('gpt2')
+
     model.eval()
+    model.to(device)
 
     logger.info(f'Loaded model on device {device}!')
 
@@ -147,7 +142,7 @@ def main(model_path: str = 'model.pth',
         input_sentence = input('User: ')
 
     # tokenize input
-    inp = torch.tensor(tokenizer(input_sentence), dtype=torch.long)
+    inp = torch.tensor(tokenizer(input_sentence)['input_ids'], dtype=torch.long)
     inp.unsqueeze_(0) # (1, seq_len)
 
     seq_len = inp.shape[-1]
@@ -160,7 +155,8 @@ def main(model_path: str = 'model.pth',
         while seq_len != max_seq_len:
 
             # get model output
-            output = model(inp) # (1, seq_len, vocab_sz)
+            output = model(inp) # Tensor, or a HF-style output with .logits
+            output = getattr(output, 'logits', output) # (1, seq_len, vocab_sz)
             toks_generated += 1
 
             # get the logits for the last token
@@ -185,7 +181,7 @@ def main(model_path: str = 'model.pth',
 
             seq_len = inp.shape[-1]
 
-    logger.info(f"Model output: {' '.join(tokenizer.decode(inp.tolist()))}")
+    print(f"Model output: {''.join(tokenizer.decode(inp.squeeze(0)))}")
     logger.info(f'Tokens / second: {toks_generated / (time.time() - start_time):.2f}')
 
     logger.info('Finished inference!')
-- 
GitLab