Commit 5bc6558f authored by Vlad-Andrei BĂDOIU (78692)

Merge branch 'feature/inference' into 'main'

Add inference code

See merge request !15
parents 9d1cb007 cbf807dd
@@ -12,8 +12,8 @@ to adapt as needed. Also see [Custom training](#custom-training).
### Inference
After training a model (or getting hold of one from other sources), there's an
-example on how to run inference can be found in `inference.py`. Feel free to
-adapt as needed.
+example on how to run inference in `inference.py`. It uses nucleus sampling,
+with adjustable top-p threshold and temperature values.
## Basic building blocks
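
As a quick sketch of the behaviour described above (assuming a trained model and tokenizer at the default paths, and with a placeholder prompt), the sampling knobs map directly onto `main()`'s arguments in `inference.py`:

    from inference import main

    # nucleus sampling with an adjustable threshold and temperature
    main(prompt='Once upon a time', top_p=0.9, temperature=0.6)

    # a temperature of 0 skips nucleus sampling and falls back to greedy decoding
    main(prompt='Once upon a time', temperature=0)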
inference.py
import os
import time
import warnings

import fire
import torch
from torch import Tensor

from optimus.tokenizers import SentencePieceTokenizer
from optimus.models import OptimusTransformer


def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
    """
    Nucleus (top-p) sampling.

    Args:
        probs (Tensor): Output probability distribution of the model. This
            should be the softmax'ed logits.
        top_p (float): Top-p threshold value for sampling.

    Returns:
        Tensor: Sampled output token, as index in the vocabulary.
    """
    # sort the probabilities in descending order and take their running sum
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # zero out every token whose preceding cumulative probability already
    # exceeds top_p, then renormalize the surviving probabilities
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

    # sample from the truncated distribution and map the sampled position
    # back to the original vocabulary index
    token = torch.multinomial(probs_sort, num_samples=1)
    token = torch.gather(probs_idx, -1, token)
    return token
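
# Example of the masking above: for probs = [0.5, 0.3, 0.15, 0.05] and
# top_p = 0.8, the cumulative sums are [0.5, 0.8, 0.95, 1.0], so
# (cumsum - prob) = [0.0, 0.5, 0.8, 0.95] and only the 0.05 entry is masked;
# the remaining [0.5, 0.3, 0.15] are renormalized by 0.95 before sampling.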


def main(model_path: str = 'model.pth',
         tokenizer_path: str = 'optimus.model',
         prompt: str | None = None,
         max_seq_len: int = 512,
         temperature: float = 0.6,
         top_p: float = 0.9,
         use_fp16: bool = True,
         device: str = 'cuda'):
    """
    Run the main inference loop for the model.

    Args:
        model_path (str): Path to the saved model.
        tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
            tokenizer is used for both training and testing, otherwise the
            results might be undefined.
        prompt (str | None): Prompt to feed to the model. If empty, the user
            will be prompted to enter text on stdin.
        max_seq_len (int): Maximum context length of the model. Make sure the
            value is similar to 'seq_len' used when training, otherwise the
            model will generalize poorly to higher context lengths.
        temperature (float): Temperature value to control randomness in
            sampling.
        top_p (float): Probability threshold used in nucleus sampling.
        use_fp16 (bool): Whether the computation should be in fp16 or
            full-precision fp32. If the stored model weights do not match the
            requested precision, they will be converted. If 'use_fp16' is
            enabled, but the model was trained with fp32, performance might
            drop significantly.
        device (str): Device where to run inference. Viable options are 'cpu',
            'cuda', 'cuda:2' etc.
    """
    print(f"Running with:\n"
          f"\t- model path: '{model_path}'\n"
          f"\t- tokenizer path: '{tokenizer_path}'\n"
          f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
          f"\t- max context length: {max_seq_len}\n"
          f"\t- temperature: {temperature}\n"
          f"\t- top_p threshold: {top_p}\n"
          f"\t- 16-bit floating-point: {use_fp16}\n"
          f"\t- running inference on device: {device}\n"
          f"Please see '--help' if you want to change these settings")

    # load tokenizer
    tok = SentencePieceTokenizer(model_path=tokenizer_path)

    # set default tensor type for things like positional encodings
    if use_fp16 is True:
        if device == 'cuda':
            # avoid warnings about set_default_tensor_type() being deprecated
            warnings.simplefilter('ignore')
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        elif device == 'cpu':
            raise ValueError("Cannot run 16-bit inference on CPU!")
    else:
        if device == 'cuda':
            warnings.simplefilter('ignore')
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
        elif device == 'cpu':
            warnings.simplefilter('ignore')
            torch.set_default_tensor_type(torch.FloatTensor)

    print("Loading model from disk...")

    # load state from file
    assert os.path.exists(model_path)
    state = torch.load(model_path, map_location=device)

    # these will be loaded into the model automatically later anyway, but we
    # need at least `dim` now to be able to compute positional encodings
    # correctly, and `vocab_sz` to make sure the tokenizer works
    vocab_sz = int(state['vocab_sz'])
    n_layers = int(state['n_layers'])
    n_heads = int(state['n_heads'])
    dim = int(state['dim'])
    p_drop = float(state['p_drop'])
    weight_tying = bool(state['weight_tying'])

    assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
        "different from the tokenizer used for training! This will result in "
        "erroneous generation!")

    # create model, load weights
    model = OptimusTransformer(vocab_sz=vocab_sz,
                               n_layers=n_layers,
                               n_heads=n_heads,
                               dim=dim,
                               p_drop=p_drop,
                               weight_tying=weight_tying)
    model.load_state_dict(state, strict=True)
    model.eval()

    print(f"Loaded model on device {device}!")

    _total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {_total_params}")

    # inference loop
    print("Starting inference...")

    if prompt is not None:
        input_sentence = prompt
    else:
        print("Waiting for user input... (prompt to complete)")
        input_sentence = input("User: ")

    # tokenize input
    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
                       dtype=torch.long)
    inp.unsqueeze_(0)  # (1, seq_len)

    seq_len = inp.shape[-1]
    toks_generated = 0
    start_time = time.time()

    with torch.inference_mode():
        while seq_len != max_seq_len:
            # get model output
            output = model(inp)  # (1, seq_len, vocab_sz)
            toks_generated += 1

            # get the logits for the last token
            logits = output[0, -1, :]  # (vocab_sz)

            # sample the most relevant token from the probability distribution
            if temperature > 0:
                # if temperature is enabled, use nucleus (top-p) sampling
                probs = torch.softmax(logits / temperature, dim=-1)
                token = sample_top_p(probs, top_p)
            else:
                # otherwise, fall back to greedy decoding
                token = torch.argmax(logits, dim=-1)
                token.unsqueeze_(0)

            # check if token is EOS
            if token == tok.eos_id:
                break

            # append the token to the input
            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1)  # (1, seq_len + 1)
            seq_len = inp.shape[-1]

    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
    print("Finished inference!")


if __name__ == '__main__':
    fire.Fire(main)
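
# Example invocations (a sketch with placeholder values, assuming the file is
# saved as inference.py); python-fire maps main()'s keyword arguments onto
# command-line flags:
#
#   python inference.py --model_path=model.pth --tokenizer_path=optimus.model \
#       --prompt="Once upon a time" --temperature=0.6 --top_p=0.9 --device=cuda
#
# Omitting --prompt makes the script read the prompt from stdin instead.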