Skip to content
Snippets Groups Projects
Unverified Commit cb1a7974 authored by Alexandru-Mihai GHERGHESCU
Browse files

Fix inference code

This should now work with any PyTorch model (Optimus is the example
given in the source code), as well as any HuggingFace model (adjusted
the code to be independent of any model source).
parent 8e97649e
No related branches found
No related tags found
1 merge request: !25 Re-factor optimus-prime code (optimus-prime v2)
Pipeline #72358 passed
import os
import time import time
import warnings import warnings
import logging import logging
from pathlib import Path
import fire import fire
import torch import torch
...@@ -37,19 +37,28 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor: ...@@ -37,19 +37,28 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
return token return token
def main(model_path: str = 'model.pth', def main(
hf_tokenizer_name: str = 'gpt2', # model + tokenizer
prompt: str | None = None, model_config_path: Path = Path('config.json'),
max_seq_len: int = 512, weights_path: Path = Path('weights.pth'),
temperature: float = 0.6, hf_tokenizer_name: str = 'gpt2',
top_p: float = 0.9,
use_fp16: bool = True, # prompt
device: str = 'cuda'): prompt: str | None = None,
# inference args
max_seq_len: int = 512,
temperature: float = 0.6,
top_p: float = 0.9,
use_fp16: bool = True,
device: str = 'cuda'
):
""" """
Run the main inference loop for the model. Run the main inference loop for the model.
Args: Args:
model_path (str): Path to the saved model. model_config_path (Path): Path to config for model.
weights_path (Path): Path to trained model weights.
hf_tokenizer_name (str): HuggingFace tokenizer. Note: Ensure the same hf_tokenizer_name (str): HuggingFace tokenizer. Note: Ensure the same
tokenizer is used for both training and testing, otherwise the tokenizer is used for both training and testing, otherwise the
results might be undefined. results might be undefined.
...@@ -72,7 +81,8 @@ def main(model_path: str = 'model.pth', ...@@ -72,7 +81,8 @@ def main(model_path: str = 'model.pth',
""" """
logger.info(f"Running with:\n" logger.info(f"Running with:\n"
f"\t- model path: '{model_path}'\n" f"\t- model config path: '{model_config_path}'\n"
f"\t- weights file: '{weights_path}'\n"
f"\t- huggingface tokenizer: '{hf_tokenizer_name}'\n" f"\t- huggingface tokenizer: '{hf_tokenizer_name}'\n"
f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n" f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
f"\t- max context length: {max_seq_len}\n" f"\t- max context length: {max_seq_len}\n"
...@@ -103,34 +113,19 @@ def main(model_path: str = 'model.pth', ...@@ -103,34 +113,19 @@ def main(model_path: str = 'model.pth',
logger.info('Loading model from disk...') logger.info('Loading model from disk...')
# load state from file # create model and load weights
assert os.path.exists(model_path) config = OptimusConfig.from_json_file(model_config_path)
state = torch.load(model_path, map_location=device)
# these will be loaded into the model automatically later anyway, but we
# need at least `dim` now to be able to compute positional encodings
# correctly, and `vocab_sz` to make sure the tokenizer works
vocab_sz = int(state['vocab_sz'])
n_layers = int(state['n_layers'])
n_heads = int(state['n_heads'])
dim = int(state['dim'])
p_drop = float(state['p_drop'])
weight_tying = bool(state['weight_tying'])
assert vocab_sz == len(tokenizer), ('The tokenizer passed for inference is '
'different from the tokenizer used for training! This will result in '
'erroneous generation!')
# create model, load weights
config = OptimusConfig(vocab_size=vocab_sz,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
hidden_size=dim,
attention_dropout=p_drop,
tie_word_embeddings=weight_tying)
model = OptimusTransformer(config) model = OptimusTransformer(config)
model.load_state_dict(state, strict=True) model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
# alternatively, comment the 3 lines above, and uncomment / change / adapt
# the following 3 lines to use a HuggingFace model
# from transformers import AutoTokenizer, AutoModelForCausalLM
# tokenizer = AutoTokenizer.from_pretrained('gpt2')
# model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval() model.eval()
model.to(device)
logger.info(f'Loaded model on device {device}!') logger.info(f'Loaded model on device {device}!')
...@@ -147,7 +142,7 @@ def main(model_path: str = 'model.pth', ...@@ -147,7 +142,7 @@ def main(model_path: str = 'model.pth',
input_sentence = input('User: ') input_sentence = input('User: ')
# tokenize input # tokenize input
inp = torch.tensor(tokenizer(input_sentence), dtype=torch.long) inp = torch.tensor(tokenizer(input_sentence)['input_ids'], dtype=torch.long)
inp.unsqueeze_(0) # (1, seq_len) inp.unsqueeze_(0) # (1, seq_len)
seq_len = inp.shape[-1] seq_len = inp.shape[-1]
...@@ -160,7 +155,8 @@ def main(model_path: str = 'model.pth', ...@@ -160,7 +155,8 @@ def main(model_path: str = 'model.pth',
while seq_len != max_seq_len: while seq_len != max_seq_len:
# get model output # get model output
output = model(inp) # (1, seq_len, vocab_sz) output = model(inp)
output = output.get('logits', output) # (1, seq_len, vocab_sz)
toks_generated += 1 toks_generated += 1
# get the logits for the last token # get the logits for the last token
...@@ -185,7 +181,7 @@ def main(model_path: str = 'model.pth', ...@@ -185,7 +181,7 @@ def main(model_path: str = 'model.pth',
seq_len = inp.shape[-1] seq_len = inp.shape[-1]
logger.info(f"Model output: {' '.join(tokenizer.decode(inp.tolist()))}") print(f"Model output: {''.join(tokenizer.decode(inp.squeeze(0)))}")
logger.info(f'Tokens / second: {toks_generated / (time.time() - start_time):.2f}') logger.info(f'Tokens / second: {toks_generated / (time.time() - start_time):.2f}')
logger.info('Finished inference!') logger.info('Finished inference!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment