Compare revisions: netsys/optimus-prime
Commits on Source (5)
@@ -12,8 +12,8 @@ to adapt as needed. Also see [Custom training](#custom-training).
 ### Inference

 After training a model (or getting hold of one from other sources), there's an
-example on how to run inference can be found in `inference.py`. Feel free to
-adapt as needed.
+example on how to run inference in `inference.py`. It uses nucleus sampling,
+with adjustable top-p threshold and temperature values.

 ## Basic building blocks
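For reference, a minimal sketch of what the revised paragraph describes, calling the inference entry point directly from Python. The keyword arguments mirror the parameters of `main` in `inference.py` below; the prompt strings and values are only illustrative:

    from inference import main

    # temperature == 0 skips nucleus sampling and falls back to greedy decoding
    main(prompt='The quick brown fox', temperature=0.0)

    # nucleus sampling: temperature controls randomness, top_p tightens the cutoff
    main(prompt='The quick brown fox', temperature=0.8, top_p=0.8)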
inference.py
import os
import time
import warnings

import fire
import torch
from torch import Tensor

from optimus.tokenizers import SentencePieceTokenizer
from optimus.models import OptimusTransformer

def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
    """
    Nucleus (top-p) sampling.

    Args:
        probs (Tensor): Output probability distribution of the model. This
            should be the softmax'ed logits.
        top_p (float): Top-p threshold value for sampling.

    Returns:
        Tensor: Sampled output token, as index in the vocabulary.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # drop tokens whose cumulative probability, excluding the token itself,
    # already exceeds the top-p threshold
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.
    # renormalize the surviving probabilities and draw a single token
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    token = torch.multinomial(probs_sort, num_samples=1)
    # map the sampled position back to the original vocabulary index
    token = torch.gather(probs_idx, -1, token)
    return token
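
# Worked example for sample_top_p (illustrative values only): with sorted
# probabilities [0.6, 0.25, 0.1, 0.05] and top_p = 0.7, the cumulative sums are
# [0.6, 0.85, 0.95, 1.0]; the mask (cumsum - prob > top_p) removes the last two
# entries, the survivors [0.6, 0.25] are renormalized to roughly [0.71, 0.29],
# and one index is drawn from that reduced distribution.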

def main(model_path: str = 'model.pth',
         tokenizer_path: str = 'optimus.model',
         prompt: str | None = None,
         max_seq_len: int = 512,
         temperature: float = 0.6,
         top_p: float = 0.9,
         use_fp16: bool = True,
         device: str = 'cuda'):
    """
    Run the main inference loop for the model.

    Args:
        model_path (str): Path to the saved model.
        tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
            tokenizer is used for both training and testing, otherwise the
            results might be undefined.
        prompt (str | None): Prompt to feed to the model. If empty, the user
            will be prompted to enter text on stdin.
        max_seq_len (int): Maximum context length of the model. Make sure the
            value is similar to 'seq_len' used when training, otherwise the
            model will generalize poorly to higher context lengths.
        temperature (float): Temperature value to control randomness in
            sampling.
        top_p (float): Probability threshold used in nucleus sampling.
        use_fp16 (bool): Whether the computation should be in fp16 or
            full-precision fp32. If the model weights are stored as fp16, they
            will be converted to fp32 and vice-versa. If 'use_fp16' is enabled,
            but the model was trained with fp32, performance might drop
            significantly.
        device (str): Device where to run inference. Viable options are 'cpu',
            'cuda', 'cuda:2' etc.
    """
print(f"Running with:\n"
f"\t- model path: '{model_path}'\n"
f"\t- tokenizer path: '{tokenizer_path}'\n"
f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
f"\t- max context length: {max_seq_len}\n"
f"\t- temperature: {temperature}\n"
f"\t- top_p threshold: {top_p}\n"
f"\t- 16-bit floating-point: {use_fp16}\n"
f"\t- running inference on device: {device}\n"
f"Please see '--help' if you want to change these settings")
# load tokenizer
tok = SentencePieceTokenizer(model_path=tokenizer_path)
# set default tensor type for things like positional encodings
if use_fp16 is True:
if device == 'cuda':
# avoid warnings about set_default_tensor_type() being deprecated
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.cuda.HalfTensor)
elif device == 'cpu':
assert 0 == 1, "Cannot run 16-bit inference on CPU!"
else:
if device == 'cuda':
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.cuda.FloatTensor)
elif device == 'cpu':
warnings.simplefilter('ignore')
torch.set_default_tensor_type(torch.FloatTensor)
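    # note: torch.set_default_tensor_type() is deprecated in recent PyTorch
    # releases; the non-deprecated equivalent is torch.set_default_dtype()
    # combined with torch.set_default_device(), if the installed version
    # supports them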
print("Loading model from disk...")
# load state from file
assert os.path.exists(model_path)
state = torch.load(model_path, map_location=device)
# these will be loaded into the model automatically later anyway, but we
# need at least `dim` now to be able to compute positional encodings
# correctly, and `vocab_sz` to make sure the tokenizer works
vocab_sz = int(state['vocab_sz'])
n_layers = int(state['n_layers'])
n_heads = int(state['n_heads'])
dim = int(state['dim'])
p_drop = float(state['p_drop'])
weight_tying = bool(state['weight_tying'])
assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
"different from the tokenizer used for training! This will result in "
"erroneous generation!")
# create model, load weights
model = OptimusTransformer(vocab_sz=vocab_sz,
n_layers=n_layers,
n_heads=n_heads,
dim=dim,
p_drop=p_drop,
weight_tying=weight_tying)
model.load_state_dict(state, strict=True)
model.eval()
print(f"Loaded model on device {device}!")
_total_params = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {_total_params}")
    # inference loop
    print("Starting inference...")

    if prompt is not None:
        input_sentence = prompt
    else:
        print("Waiting for user input... (prompt to complete)")
        input_sentence = input("User: ")

    # tokenize input
    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
                       dtype=torch.long)
    inp.unsqueeze_(0)  # (1, seq_len)
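    # at this point `inp` holds the BOS id followed by the prompt's subword ids
    # (eos is deliberately left off so the model continues the sequence), with
    # a leading batch dimension of 1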
    seq_len = inp.shape[-1]
    toks_generated = 0
    start_time = time.time()
    with torch.inference_mode():
        while seq_len != max_seq_len:
            # get model output
            output = model(inp)  # (1, seq_len, vocab_sz)
            toks_generated += 1

            # get the logits for the last token
            logits = output[0, -1, :]  # (vocab_sz)

            # sample the most relevant token from the probability distribution
            if temperature > 0:
                # if temperature is enabled, use nucleus (top-p) sampling
                probs = torch.softmax(logits / temperature, dim=-1)
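                # temperature < 1 sharpens the distribution (closer to greedy),
                # while temperature > 1 flattens it towards uniform sampling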
                token = sample_top_p(probs, top_p)
            else:
                # otherwise, fall back to greedy decoding
                token = torch.argmax(logits, dim=-1)
                token.unsqueeze_(0)

            # check if token is EOS
            if token == tok.eos_id:
                break

            # append the token to the input
            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1)  # (1, seq_len + 1)
            seq_len = inp.shape[-1]

    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")

    print("Finished inference!")

if __name__ == '__main__':
    fire.Fire(main)
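Because the script hands `main` to `fire.Fire`, each of the parameters above is also exposed as a command-line flag, e.g. `python inference.py --prompt "Hello" --temperature 0.8 --top_p 0.8 --device cuda` (the values shown are illustrative, not recommended settings).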