diff --git a/inference.py b/inference.py index c0939cdd16e76f2553763019347e6a544bd9b40f..1606a2a93609896ced9e4d870e7c42b7a3c19f2f 100644 --- a/inference.py +++ b/inference.py @@ -34,6 +34,7 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor: def main(model_path: str = 'model.pth', tokenizer_path: str = 'optimus.model', + prompt: str | None = None, max_seq_len: int = 512, temperature: float = 0.6, top_p: float = 0.9, @@ -47,6 +48,8 @@ def main(model_path: str = 'model.pth', tokenizer_path (str): Path to the tokenizer. Note: Ensure the same tokenizer is used for both training and testing, otherwise the results might be undefined. + prompt (str | None): Prompt to feed to the model. If empty, the user + will be prompted to enter text on stdin. max_seq_len (int): Maximum context length of the model. Make sure the value is similar to 'seq_len' used when training, otherwise the model will generalize poorly to higher context lengths. @@ -66,6 +69,7 @@ def main(model_path: str = 'model.pth', print(f"Running with:\n" f"\t- model path: '{model_path}'\n" f"\t- tokenizer path: '{tokenizer_path}'\n" + f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n" f"\t- max context length: {max_seq_len}\n" f"\t- temperature: {temperature}\n" f"\t- top_p threshold: {top_p}\n" @@ -130,9 +134,11 @@ def main(model_path: str = 'model.pth', # inference loop print("Starting inference...") - print("Done! Waiting for user input... (prompt to complete)") - - input_sentence = input("User: ") + if prompt is not None: + input_sentence = prompt + else: + print("Waiting for user input... (prompt to complete)") + input_sentence = input("User: ") # tokenize input inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),