diff --git a/inference.py b/inference.py
index c0939cdd16e76f2553763019347e6a544bd9b40f..1606a2a93609896ced9e4d870e7c42b7a3c19f2f 100644
--- a/inference.py
+++ b/inference.py
@@ -34,6 +34,7 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
 
 def main(model_path: str = 'model.pth',
          tokenizer_path: str = 'optimus.model',
+         prompt: str | None = None,
          max_seq_len: int = 512,
          temperature: float = 0.6,
          top_p: float = 0.9,
@@ -47,6 +48,8 @@ def main(model_path: str = 'model.pth',
         tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
             tokenizer is used for both training and testing, otherwise the
             results might be undefined.
+        prompt (str | None): Prompt to feed to the model. If None, the user
+            will be prompted to enter text on stdin.
         max_seq_len (int): Maximum context length of the model. Make sure the
             value is similar to 'seq_len' used when training, otherwise the
             model will generalize poorly to higher context lengths.
@@ -66,6 +69,7 @@ def main(model_path: str = 'model.pth',
     print(f"Running with:\n"
         f"\t- model path: '{model_path}'\n"
         f"\t- tokenizer path: '{tokenizer_path}'\n"
+        f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
         f"\t- max context length: {max_seq_len}\n"
         f"\t- temperature: {temperature}\n"
         f"\t- top_p threshold: {top_p}\n"
@@ -130,9 +134,12 @@ def main(model_path: str = 'model.pth',
     # inference loop
     print("Starting inference...")
 
-    print("Done! Waiting for user input... (prompt to complete)")
-
-    input_sentence = input("User: ")
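+    # use the prompt argument when provided; otherwise fall back to interactive input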
+    if prompt is not None:
+        input_sentence = prompt
+    else:
+        print("Waiting for user input... (prompt to complete)")
+        input_sentence = input("User: ")
 
     # tokenize input
     inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),