From 0ee3b1082d5c2281df99b5d88532603453974b4f Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Fri, 26 Jan 2024 20:36:17 +0200
Subject: [PATCH] Add option to pass a prompt to the inference script

This allows the inference code to start up with a prompt, instead of
waiting for user input from stdin. This makes scripting easier, which
is useful for batch generation, benchmarking, etc.
---
 inference.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/inference.py b/inference.py
index c0939cd..1606a2a 100644
--- a/inference.py
+++ b/inference.py
@@ -34,6 +34,7 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
 
 def main(model_path: str = 'model.pth',
          tokenizer_path: str = 'optimus.model',
+         prompt: str | None = None,
          max_seq_len: int = 512,
          temperature: float = 0.6,
          top_p: float = 0.9,
@@ -47,6 +48,8 @@ def main(model_path: str = 'model.pth',
         tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
             tokenizer is used for both training and testing, otherwise the
             results might be undefined.
+        prompt (str | None): Prompt to feed to the model. If not given, the
+            user will be prompted to enter text on stdin.
         max_seq_len (int): Maximum context length of the model. Make sure the
             value is similar to 'seq_len' used when training, otherwise the
             model will generalize poorly to higher context lengths.
@@ -66,6 +69,7 @@ def main(model_path: str = 'model.pth',
     print(f"Running with:\n"
           f"\t- model path: '{model_path}'\n"
           f"\t- tokenizer path: '{tokenizer_path}'\n"
+          f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
           f"\t- max context length: {max_seq_len}\n"
           f"\t- temperature: {temperature}\n"
           f"\t- top_p threshold: {top_p}\n"
@@ -130,9 +134,11 @@ def main(model_path: str = 'model.pth',
 
     # inference loop
     print("Starting inference...")
-    print("Done! Waiting for user input... (prompt to complete)")
-
-    input_sentence = input("User: ")
+    if prompt is not None:
+        input_sentence = prompt
+    else:
+        print("Waiting for user input... (prompt to complete)")
+        input_sentence = input("User: ")
 
     # tokenize input
     inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
--
GitLab
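
Usage note (not part of the patch itself): a minimal sketch of the kind
of scripted batch generation the new 'prompt' argument enables, assuming
inference.py is importable and main() can be called directly. The prompt
strings below are placeholders; model and tokenizer paths are the
defaults from the function signature above.

    # batch_generate.py - hypothetical driver for the patched inference.py
    from inference import main

    for p in ["Once upon a time", "The capital of France is"]:
        # Each call runs one non-interactive generation; with prompt=None
        # (the default), the script falls back to reading from stdin.
        main(model_path='model.pth',
             tokenizer_path='optimus.model',
             prompt=p)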