Skip to content
Snippets Groups Projects
Unverified Commit 0ee3b108 authored by Alexandru-Mihai GHERGHESCU's avatar Alexandru-Mihai GHERGHESCU
Browse files

Add option to pass a prompt to the inference script

This allows the inference code to start up with a prompt, instead of
waiting for user input from stdin. Allows easier scripting, useful for
batch generation, benchmarking etc.
parent 3ab5c4fd
No related branches found
No related tags found
1 merge request!15Add inference code
......@@ -34,6 +34,7 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
def main(model_path: str = 'model.pth',
tokenizer_path: str = 'optimus.model',
prompt: str | None = None,
max_seq_len: int = 512,
temperature: float = 0.6,
top_p: float = 0.9,
......@@ -47,6 +48,8 @@ def main(model_path: str = 'model.pth',
tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
tokenizer is used for both training and testing, otherwise the
results might be undefined.
prompt (str | None): Prompt to feed to the model. If empty, the user
will be prompted to enter text on stdin.
max_seq_len (int): Maximum context length of the model. Make sure the
value is similar to 'seq_len' used when training, otherwise the
model will generalize poorly to higher context lengths.
......@@ -66,6 +69,7 @@ def main(model_path: str = 'model.pth',
print(f"Running with:\n"
f"\t- model path: '{model_path}'\n"
f"\t- tokenizer path: '{tokenizer_path}'\n"
f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
f"\t- max context length: {max_seq_len}\n"
f"\t- temperature: {temperature}\n"
f"\t- top_p threshold: {top_p}\n"
......@@ -130,9 +134,11 @@ def main(model_path: str = 'model.pth',
# inference loop
print("Starting inference...")
print("Done! Waiting for user input... (prompt to complete)")
input_sentence = input("User: ")
if prompt is not None:
input_sentence = prompt
else:
print("Waiting for user input... (prompt to complete)")
input_sentence = input("User: ")
# tokenize input
inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment