diff --git a/README.md b/README.md
index d426eed8da1f7987403f508faace51e76d863046..72f6441766a3c11ef32757db0326b8052b7b8d37 100644
--- a/README.md
+++ b/README.md
@@ -12,8 +12,8 @@ to adapt as needed. Also see [Custom training](#custom-training).
 ### Inference
 
 After training a model (or getting hold of one from other sources), there's an
-example on how to run inference can be found in `inference.py`. Feel free to
-adapt as needed.
+example of how to run inference in `inference.py`. It uses nucleus sampling,
+with adjustable top-p threshold and temperature values.
 
 ## Basic building blocks
 
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a81ae2eb82c916f91c62d9d149d33cdcb2cabb
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,191 @@
+import os
+import time
+import warnings
+
+import fire
+import torch
+from torch import Tensor
+
+from optimus.tokenizers import SentencePieceTokenizer
+from optimus.models import OptimusTransformer
+
+
+def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
+    """
+    Nucleus (top-p) sampling.
+
+    Args:
+        probs (Tensor): Output probability distribution of the model. This
+            should be the softmax'ed logits.
+        top_p (float): Top-p threshold value for sampling.
+
+    Returns:
+        Tensor: Sampled output token, as an index into the vocabulary.
+
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > top_p
+    probs_sort[mask] = 0.
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    token = torch.multinomial(probs_sort, num_samples=1)
+    token = torch.gather(probs_idx, -1, token)
+    return token
+
+
+def main(model_path: str = 'model.pth',
+         tokenizer_path: str = 'optimus.model',
+         prompt: str | None = None,
+         max_seq_len: int = 512,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         use_fp16: bool = True,
+         device: str = 'cuda'):
+    """
+    Run the main inference loop for the model.
+
+    Args:
+        model_path (str): Path to the saved model.
+        tokenizer_path (str): Path to the tokenizer. Note: Make sure the same
+            tokenizer is used for both training and inference, otherwise the
+            results might be undefined.
+        prompt (str | None): Prompt to feed to the model. If None, the user
+            will be prompted to enter text on stdin.
+        max_seq_len (int): Maximum context length of the model. Make sure the
+            value is similar to the 'seq_len' used during training, otherwise
+            the model will generalize poorly to longer context lengths.
+        temperature (float): Temperature value to control randomness in
+            sampling.
+        top_p (float): Probability threshold used in nucleus sampling.
+        use_fp16 (bool): Whether computation should run in fp16 or in
+            full-precision fp32. The model weights are converted to the
+            requested precision if they are stored differently. If 'use_fp16'
+            is enabled but the model was trained in fp32, the quality of the
+            generated text might drop significantly.
+        device (str): Device where to run inference. Viable options are 'cpu',
+            'cuda', 'cuda:2' etc.
+ + """ + + print(f"Running with:\n" + f"\t- model path: '{model_path}'\n" + f"\t- tokenizer path: '{tokenizer_path}'\n" + f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n" + f"\t- max context length: {max_seq_len}\n" + f"\t- temperature: {temperature}\n" + f"\t- top_p threshold: {top_p}\n" + f"\t- 16-bit floating-point: {use_fp16}\n" + f"\t- running inference on device: {device}\n" + f"Please see '--help' if you want to change these settings") + + # load tokenizer + tok = SentencePieceTokenizer(model_path=tokenizer_path) + + # set default tensor type for things like positional encodings + if use_fp16 is True: + if device == 'cuda': + # avoid warnings about set_default_tensor_type() being deprecated + warnings.simplefilter('ignore') + torch.set_default_tensor_type(torch.cuda.HalfTensor) + elif device == 'cpu': + assert 0 == 1, "Cannot run 16-bit inference on CPU!" + else: + if device == 'cuda': + warnings.simplefilter('ignore') + torch.set_default_tensor_type(torch.cuda.FloatTensor) + elif device == 'cpu': + warnings.simplefilter('ignore') + torch.set_default_tensor_type(torch.FloatTensor) + + print("Loading model from disk...") + + # load state from file + assert os.path.exists(model_path) + state = torch.load(model_path, map_location=device) + + # these will be loaded into the model automatically later anyway, but we + # need at least `dim` now to be able to compute positional encodings + # correctly, and `vocab_sz` to make sure the tokenizer works + vocab_sz = int(state['vocab_sz']) + n_layers = int(state['n_layers']) + n_heads = int(state['n_heads']) + dim = int(state['dim']) + p_drop = float(state['p_drop']) + weight_tying = bool(state['weight_tying']) + + assert vocab_sz == len(tok), ("The tokenizer passed for inference is " + "different from the tokenizer used for training! This will result in " + "erroneous generation!") + + # create model, load weights + model = OptimusTransformer(vocab_sz=vocab_sz, + n_layers=n_layers, + n_heads=n_heads, + dim=dim, + p_drop=p_drop, + weight_tying=weight_tying) + model.load_state_dict(state, strict=True) + model.eval() + + print(f"Loaded model on device {device}!") + + _total_params = sum(p.numel() for p in model.parameters()) + print(f"Number of model parameters: {_total_params}") + + # inference loop + print("Starting inference...") + + if prompt is not None: + input_sentence = prompt + else: + print("Waiting for user input... 
+        input_sentence = input("User: ")
+
+    # tokenize input
+    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
+                       dtype=torch.long)
+    inp.unsqueeze_(0)  # (1, seq_len)
+
+    seq_len = inp.shape[-1]
+
+    toks_generated = 0
+    start_time = time.time()
+
+    with torch.inference_mode():
+
+        while seq_len < max_seq_len:
+
+            # get model output
+            output = model(inp)  # (1, seq_len, vocab_sz)
+            toks_generated += 1
+
+            # get the logits for the last token
+            logits = output[0, -1, :]  # (vocab_sz)
+
+            # sample the next token from the probability distribution
+            if temperature > 0:
+                # if temperature is enabled, use nucleus (top-p) sampling
+                probs = torch.softmax(logits / temperature, dim=-1)
+                token = sample_top_p(probs, top_p)
+            else:
+                # otherwise, fall back to greedy decoding
+                token = torch.argmax(logits, dim=-1)
+                token.unsqueeze_(0)
+
+            # check if token is EOS
+            if token == tok.eos_id:
+                break
+
+            # append the token to the input
+            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1)  # (1, seq_len + 1)
+
+            seq_len = inp.shape[-1]
+
+    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
+    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
+
+    print("Finished inference!")
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
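
Because the script hands `main` to `fire.Fire`, every keyword argument of `main` doubles as a command-line flag (e.g. `python inference.py --prompt="Hello" --temperature=0.8 --use_fp16=False --device=cpu`). The sketch below is the equivalent programmatic call; it assumes the `optimus` package is importable and that `model.pth` and `optimus.model` exist, and the prompt string and CPU/fp32 settings are only illustrative, not defaults prescribed by this patch.

```python
# Minimal usage sketch (assumes the `optimus` package is installed and that
# `model.pth` / `optimus.model` were produced by the training scripts).
# The prompt string and the CPU/fp32 settings below are only illustrative.
from inference import main

main(model_path='model.pth',          # default checkpoint path from the patch
     tokenizer_path='optimus.model',  # default tokenizer path from the patch
     prompt='Once upon a time',       # hypothetical prompt; skips the stdin prompt
     max_seq_len=512,
     temperature=0.8,                 # > 0 enables nucleus sampling
     top_p=0.9,
     use_fp16=False,                  # fp16 is rejected on CPU by main()
     device='cpu')
```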
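
To make the nucleus-sampling step easier to follow, here is a small sketch that runs `sample_top_p` on a toy distribution (again assuming `inference.py` and its dependencies are importable). The probability values and the `top_p=0.7` threshold are made up for illustration; with them, only the two most likely tokens survive the cumulative-probability mask, so every draw comes from that nucleus.

```python
# Toy illustration of the top-p mask in sample_top_p(): tokens outside the
# smallest high-probability set (the "nucleus") are zeroed out before sampling.
# The probabilities and the top_p=0.7 threshold are made-up example values.
import torch

from inference import sample_top_p

probs = torch.tensor([0.6, 0.25, 0.1, 0.05])  # pretend vocabulary of 4 tokens
# cumulative sums: [0.60, 0.85, 0.95, 1.00]
# with top_p=0.7, tokens 2 and 3 are masked out and the remaining mass is
# renormalized to roughly [0.71, 0.29, 0, 0]
for _ in range(5):
    token = sample_top_p(probs, top_p=0.7)
    assert token.item() in (0, 1)  # only nucleus tokens can ever be drawn
```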