Commit accab39c authored by Alexandru-Mihai GHERGHESCU

Add inference code

Inference example code. At the moment, the code simply loads a model
state file and generates text from it. Parameters such as the maximum
sequence length, whether training used fp16, and which tokenizer was
used for training need to be passed manually by the user (there is a
lot of room for error here). To be improved.

Merges changes from !14
Closes !14
parent 9d1cb007
Merge request !15: Add inference code
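Since the entry point is wrapped with fire.Fire(main), the parameters mentioned in the commit message are exposed both as keyword arguments of main() and as command-line flags. A minimal sketch of a programmatic invocation (the module name inference is an assumption; the values mirror the defaults defined in main() below):

    from inference import main  # hypothetical module name for this file

    main(model_path='model.pth',
         tokenizer_path='optimus.model',
         max_seq_len=512,
         temperature=0.6,
         top_p=0.9,
         use_fp16=True,
         device='cuda')

Equivalently, the same settings can be passed on the command line, e.g. python inference.py --temperature 0.6 --device cuda.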
import os
import warnings

import fire
import torch
from torch import Tensor

from optimus.tokenizers import SentencePieceTokenizer
from optimus.models import OptimusTransformer

def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
    """
    Nucleus (top-p) sampling.

    Args:
        probs (Tensor): Output probability distribution of the model. This
            should be the softmax'ed logits.
        top_p (float): Top-p threshold value for sampling.

    Returns:
        Tensor: Sampled output token, as index in the vocabulary.
    """
    # sort the probabilities in descending order and compute their cumulative sum
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # zero out every token whose preceding cumulative mass already exceeds top_p
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.
    # renormalize the surviving probabilities and sample a single token from them
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    token = torch.multinomial(probs_sort, num_samples=1)
    # map the sampled position back to its original vocabulary index
    token = torch.gather(probs_idx, -1, token)
    return token
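
# Illustrative example of sample_top_p(), added for clarity and not part of
# the original commit (the numbers are hand-computed, so treat this as a
# sketch):
#
#   probs = torch.tensor([0.6, 0.25, 0.1, 0.05])
#   token = sample_top_p(probs, top_p=0.7)
#
# The cumulative sums are [0.6, 0.85, 0.95, 1.0], so the mask zeroes out the
# last two entries (the cumulative mass before each of them already exceeds
# 0.7). After renormalization the sample is drawn from
# [0.6/0.85, 0.25/0.85, 0, 0], i.e. `token` is always index 0 or 1.
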
def main(model_path: str = 'model.pth',
         tokenizer_path: str = 'optimus.model',
         max_seq_len: int = 512,
         temperature: float = 0.6,
         top_p: float = 0.9,
         use_fp16: bool = True,
         device: str = 'cuda'):
    """
    Run the main inference loop for the model.

    Args:
        model_path (str): Path to the saved model.
        tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
            tokenizer is used for both training and inference, otherwise the
            results might be undefined.
        max_seq_len (int): Maximum context length of the model. Make sure the
            value is similar to 'seq_len' used when training, otherwise the
            model will generalize poorly to higher context lengths.
        temperature (float): Temperature value to control randomness in
            sampling.
        top_p (float): Probability threshold used in nucleus sampling.
        use_fp16 (bool): Whether the computation should be in fp16 or
            full-precision fp32. If the model weights are stored as fp16, they
            will be converted to fp32 and vice-versa. If 'use_fp16' is enabled,
            but the model was trained with fp32, performance might drop
            significantly.
        device (str): Device where to run inference. Viable options are 'cpu',
            'cuda', 'cuda:2' etc.
    """
print(f"Running with:\n"
f"\t- model path: '{model_path}'\n"
f"\t- tokenizer path: '{tokenizer_path}'\n"
f"\t- max context length: {max_seq_len}\n"
f"\t- temperature: {temperature}\n"
f"\t- top_p threshold: {top_p}\n"
f"\t- 16-bit floating-point: {use_fp16}\n"
f"\t- running inference on device: {device}\n"
f"Please see '--help' if you want to change these settings")
# load tokenizer
tok = SentencePieceTokenizer(model_path=tokenizer_path)
# set default tensor type for things like positional encodings
if use_fp16 is True:
if device == 'cuda':
# avoid warnings about set_default_tensor_type() being deprecated
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.cuda.HalfTensor)
elif device == 'cpu':
assert 0 == 1, "Cannot run 16-bit inference on CPU!"
else:
if device == 'cuda':
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.cuda.FloatTensor)
elif device == 'cpu':
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.FloatTensor)
print("Loading model from disk...")
# load state from file
assert os.path.exists(model_path)
state = torch.load(model_path, map_location=device)
# these will be loaded into the model automatically later anyway, but we
# need at least `dim` now to be able to compute positional encodings
# correctly, and `vocab_sz` to make sure the tokenizer works
vocab_sz = int(state['vocab_sz'])
n_layers = int(state['n_layers'])
n_heads = int(state['n_heads'])
dim = int(state['dim'])
p_drop = float(state['p_drop'])
weight_tying = bool(state['weight_tying'])
assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
"different from the tokenizer used for training! This will result in "
"erroneous generation!")
# create model, load weights
model = OptimusTransformer(vocab_sz=vocab_sz,
n_layers=n_layers,
n_heads=n_heads,
dim=dim,
p_drop=p_drop,
weight_tying=weight_tying)
model.load_state_dict(state, strict=True)
model.eval()
print(f"Loaded model on device {device}!")
_total_params = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {_total_params}")
    # inference loop
    print("Starting inference...")
    print("Done! Waiting for user input... (prompt to complete)")
    input_sentence = input("User: ")

    # tokenize input
    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
                       dtype=torch.long)
    inp.unsqueeze_(0)  # (1, seq_len)

    seq_len = inp.shape[-1]

    with torch.inference_mode():
        while seq_len != max_seq_len:
            # get model output
            output = model(inp)  # (1, seq_len, vocab_sz)

            # get the logits for the last token
            logits = output[0, -1, :]  # (vocab_sz)

            # sample the next token from the probability distribution
            if temperature > 0:
                # if temperature is enabled, use nucleus (top-p) sampling
                probs = torch.softmax(logits / temperature, dim=-1)
                token = sample_top_p(probs, top_p)
            else:
                # otherwise, fall back to greedy decoding
                token = torch.argmax(logits, dim=-1)
                token.unsqueeze_(0)

            # stop if the sampled token is EOS
            if token == tok.eos_id:
                break

            # append the token to the input
            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1)  # (1, seq_len + 1)
            seq_len = inp.shape[-1]

    print("Model output: ", ' '.join(tok.decode(inp.tolist())))
    print("Finished inference!")

if __name__ == '__main__':
    fire.Fire(main)
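
The loader above expects the file at model_path to be a single dict holding both the OptimusTransformer weights and the hyperparameters it reads back ('vocab_sz', 'n_layers', 'n_heads', 'dim', 'p_drop', 'weight_tying'). A hedged sketch of how a compatible checkpoint might be written on the training side follows; it is not taken from this repository, the hyperparameter values are placeholders, and it assumes OptimusTransformer tolerates the extra keys when loading with strict=True:

    import torch

    def save_checkpoint(model, tok, path='model.pth'):
        # model: a trained OptimusTransformer; tok: the SentencePieceTokenizer
        # used during training (both assumed/placeholder objects)
        state = model.state_dict()
        state.update({
            'vocab_sz': len(tok),      # must match the tokenizer used at inference
            'n_layers': 6,             # placeholder value
            'n_heads': 8,              # placeholder value
            'dim': 512,                # placeholder value
            'p_drop': 0.0,             # placeholder value
            'weight_tying': False,     # placeholder value
        })
        torch.save(state, path)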