diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0939cdd16e76f2553763019347e6a544bd9b40f
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,179 @@
+import os
+import warnings
+
+import fire
+import torch
+from torch import Tensor
+
+from optimus.tokenizers import SentencePieceTokenizer
+from optimus.models import OptimusTransformer
+
+
+def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
+    """
+    Nucleus (top-p) sampling.
+
+    Args:
+        probs (Tensor): Output probability distribution of the model. This
+            should be the softmax'ed logits.
+        top_p (float): Top-p threshold value for sampling.
+
+    Returns:
+        Tensor: Sampled output token, as an index in the vocabulary.
+
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > top_p
+    probs_sort[mask] = 0.
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    token = torch.multinomial(probs_sort, num_samples=1)
+    token = torch.gather(probs_idx, -1, token)
+    return token
+
+
+def main(model_path: str = 'model.pth',
+         tokenizer_path: str = 'optimus.model',
+         max_seq_len: int = 512,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         use_fp16: bool = True,
+         device: str = 'cuda'):
+    """
+    Run the main inference loop for the model.
+
+    Args:
+        model_path (str): Path to the saved model.
+        tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
+            tokenizer is used for both training and inference, otherwise the
+            results might be undefined.
+        max_seq_len (int): Maximum context length of the model. Make sure the
+            value is similar to the 'seq_len' used when training, otherwise
+            the model will generalize poorly to longer context lengths.
+        temperature (float): Temperature value to control randomness in
+            sampling.
+        top_p (float): Probability threshold used in nucleus sampling.
+        use_fp16 (bool): Whether computation should run in fp16 or in full
+            precision fp32. Weights stored in a different precision than the
+            one requested will be converted accordingly. If 'use_fp16' is
+            enabled but the model was trained in fp32, the results might
+            degrade significantly.
+        device (str): Device on which to run inference. Viable options are
+            'cpu', 'cuda', 'cuda:2' etc.
+
+    """
+
+    print(f"Running with:\n"
+          f"\t- model path: '{model_path}'\n"
+          f"\t- tokenizer path: '{tokenizer_path}'\n"
+          f"\t- max context length: {max_seq_len}\n"
+          f"\t- temperature: {temperature}\n"
+          f"\t- top_p threshold: {top_p}\n"
+          f"\t- 16-bit floating-point: {use_fp16}\n"
+          f"\t- running inference on device: {device}\n"
+          f"Please see '--help' if you want to change these settings")
+
+    # load tokenizer
+    tok = SentencePieceTokenizer(model_path=tokenizer_path)
+
+    # set default tensor type for things like positional encodings
+    if use_fp16:
+        if device == 'cuda':
+            # avoid warnings about set_default_tensor_type() being deprecated
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        elif device == 'cpu':
+            raise ValueError("Cannot run 16-bit inference on CPU!")
+    else:
+        if device == 'cuda':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        elif device == 'cpu':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.FloatTensor)
+
+    print("Loading model from disk...")
+
+    # load state from file
+    assert os.path.exists(model_path), f"Model file not found: '{model_path}'"
+    state = torch.load(model_path, map_location=device)
+
+    # these will be loaded into the model automatically later anyway, but we
+    # need at least `dim` now to be able to compute positional encodings
+    # correctly, and `vocab_sz` to make sure the tokenizer works
+    vocab_sz = int(state['vocab_sz'])
+    n_layers = int(state['n_layers'])
+    n_heads = int(state['n_heads'])
+    dim = int(state['dim'])
+    p_drop = float(state['p_drop'])
+    weight_tying = bool(state['weight_tying'])
+
+    assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
+        "different from the tokenizer used for training! This will result in "
+        "erroneous generation!")
+
+    # create model, load weights
+    model = OptimusTransformer(vocab_sz=vocab_sz,
+                               n_layers=n_layers,
+                               n_heads=n_heads,
+                               dim=dim,
+                               p_drop=p_drop,
+                               weight_tying=weight_tying)
+    model.load_state_dict(state, strict=True)
+    model.eval()
+
+    print(f"Loaded model on device {device}!")
+
+    _total_params = sum(p.numel() for p in model.parameters())
+    print(f"Number of model parameters: {_total_params}")
+
+    # inference loop
+    print("Starting inference...")
+
+    print("Done! Waiting for user input... (prompt to complete)")
+
+    input_sentence = input("User: ")
+
+    # tokenize input
+    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
+                       dtype=torch.long)
+    inp.unsqueeze_(0)  # (1, seq_len)
+
+    seq_len = inp.shape[-1]
+
+    with torch.inference_mode():
+
+        while seq_len < max_seq_len:
+
+            # get model output
+            output = model(inp)  # (1, seq_len, vocab_sz)
+
+            # get the logits for the last token
+            logits = output[0, -1, :]  # (vocab_sz)
+
+            # sample the next token from the probability distribution
+            if temperature > 0:
+                # if temperature is enabled, use nucleus (top-p) sampling
+                probs = torch.softmax(logits / temperature, dim=-1)
+                token = sample_top_p(probs, top_p)
+            else:
+                # otherwise, fall back to greedy decoding
+                token = torch.argmax(logits, dim=-1)
+                token.unsqueeze_(0)
+
+            # check if the token is EOS
+            if token == tok.eos_id:
+                break
+
+            # append the token to the input
+            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1)  # (1, seq_len + 1)
+
+            seq_len = inp.shape[-1]
+
+    print("Model output: ", ' '.join(tok.decode(inp.tolist())))
+
+    print("Finished inference!")
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
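
Because main() is exposed through fire.Fire, every keyword argument above automatically becomes a command-line flag (e.g. --temperature=0.8). The entry point can also be driven directly from Python; the sketch below is only an illustration, reusing the parameter names from the diff and assuming inference.py is importable and that the model and tokenizer files exist at the given paths.

    # minimal usage sketch (paths are placeholders; adjust to your setup)
    from inference import main

    # roughly equivalent to:
    #   python inference.py --temperature=0.8 --use_fp16=False --device=cpu
    main(model_path='model.pth',          # checkpoint produced by training
         tokenizer_path='optimus.model',  # same tokenizer used for training
         max_seq_len=512,
         temperature=0.8,
         top_p=0.9,
         use_fp16=False,                  # fp16 is only supported on 'cuda'
         device='cpu')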