diff --git a/README.md b/README.md
index d426eed8da1f7987403f508faace51e76d863046..72f6441766a3c11ef32757db0326b8052b7b8d37 100644
--- a/README.md
+++ b/README.md
@@ -12,8 +12,8 @@ to adapt as needed. Also see [Custom training](#custom-training).
 ### Inference
 
 After training a model (or getting hold of one from other sources), there's an
-example on how to run inference can be found in `inference.py`. Feel free to
-adapt as needed.
+example on how to run inference in `inference.py`. It uses nucleus sampling,
+with adjustable top-p threshold and temperature values.
 
 ## Basic building blocks
 
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a81ae2eb82c916f91c62d9d149d33cdcb2cabb
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,191 @@
+import os
+import time
+import warnings
+
+import fire
+import torch
+from torch import Tensor
+
+from optimus.tokenizers import SentencePieceTokenizer
+from optimus.models import OptimusTransformer
+
+
+def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
+    """
+    Nucleus (top-p) sampling.
+
+    Args:
+        probs (Tensor): Output probability distribution of the model. This
+            should be the softmax'ed logits.
+        top_p (float): Top-p threshold value for sampling.
+
+    Returns:
+        Tensor: Sampled output token, as index in the vocabulary.
+
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > top_p
+    probs_sort[mask] = 0.
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    token = torch.multinomial(probs_sort, num_samples=1)
+    token = torch.gather(probs_idx, -1, token)
+    return token
+
+
+def main(model_path: str = 'model.pth',
+         tokenizer_path: str = 'optimus.model',
+         prompt: str | None = None,
+         max_seq_len: int = 512,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         use_fp16: bool = True,
+         device: str = 'cuda'):
+    """
+    Run the main inference loop for the model.
+
+    Args:
+        model_path (str): Path to the saved model.
+        tokenizer_path (str): Path to the tokenizer. Note: Ensure the same
+            tokenizer is used for both training and testing, otherwise the
+            results might be undefined.
+        prompt (str | None): Prompt to feed to the model. If empty, the user
+            will be prompted to enter text on stdin.
+        max_seq_len (int): Maximum context length of the model. Make sure the
+            value is similar to 'seq_len' used when training, otherwise the
+            model will generalize poorly to higher context lengths.
+        temperature (float): Temperature value to control randomness in
+            sampling.
+        top_p (float): Probability threshold used in nucleus sampling.
+        use_fp16 (bool): Whether the computation should be in fp16 or
+            full-precision fp32. If the model weights are stored as fp16, they
+            will be converted to fp32 and vice-versa. If 'use_fp16' is enabled,
+            but the model was trained with fp32, performance might drop
+            significantly.
+        device (str): Device where to run inference. Viable options are 'cpu',
+            'cuda', 'cuda:2' etc.
+
+    """
+
+    print(f"Running with:\n"
+        f"\t- model path: '{model_path}'\n"
+        f"\t- tokenizer path: '{tokenizer_path}'\n"
+        f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
+        f"\t- max context length: {max_seq_len}\n"
+        f"\t- temperature: {temperature}\n"
+        f"\t- top_p threshold: {top_p}\n"
+        f"\t- 16-bit floating-point: {use_fp16}\n"
+        f"\t- running inference on device: {device}\n"
+        f"Please see '--help' if you want to change these settings")
+
+    # load tokenizer
+    tok = SentencePieceTokenizer(model_path=tokenizer_path)
+
+    # set default tensor type for things like positional encodings
+    if use_fp16 is True:
+        if device == 'cuda':
+            # avoid warnings about set_default_tensor_type() being deprecated
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        elif device == 'cpu':
+            assert 0 == 1, "Cannot run 16-bit inference on CPU!"
+    else:
+        if device == 'cuda':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        elif device == 'cpu':
+            warnings.simplefilter('ignore')
+            torch.set_default_tensor_type(torch.FloatTensor)
+
+    print("Loading model from disk...")
+
+    # load state from file
+    assert os.path.exists(model_path)
+    state = torch.load(model_path, map_location=device)
+
+    # these will be loaded into the model automatically later anyway, but we
+    # need at least `dim` now to be able to compute positional encodings
+    # correctly, and `vocab_sz` to make sure the tokenizer works
+    vocab_sz = int(state['vocab_sz'])
+    n_layers = int(state['n_layers'])
+    n_heads = int(state['n_heads'])
+    dim = int(state['dim'])
+    p_drop = float(state['p_drop'])
+    weight_tying = bool(state['weight_tying'])
+
+    assert vocab_sz == len(tok), ("The tokenizer passed for inference is "
+        "different from the tokenizer used for training! This will result in "
+        "erroneous generation!")
+
+    # create model, load weights
+    model = OptimusTransformer(vocab_sz=vocab_sz,
+                               n_layers=n_layers,
+                               n_heads=n_heads,
+                               dim=dim,
+                               p_drop=p_drop,
+                               weight_tying=weight_tying)
+    model.load_state_dict(state, strict=True)
+    model.eval()
+
+    print(f"Loaded model on device {device}!")
+
+    _total_params = sum(p.numel() for p in model.parameters())
+    print(f"Number of model parameters: {_total_params}")
+
+    # inference loop
+    print("Starting inference...")
+
+    if prompt is not None:
+        input_sentence = prompt
+    else:
+        print("Waiting for user input... (prompt to complete)")
+        input_sentence = input("User: ")
+
+    # tokenize input
+    inp = torch.tensor(tok.encode(input_sentence, bos=True, eos=False),
+                       dtype=torch.long)
+    inp.unsqueeze_(0) # (1, seq_len)
+
+    seq_len = inp.shape[-1]
+
+    toks_generated = 0
+    start_time = time.time()
+
+    with torch.inference_mode():
+
+        while seq_len != max_seq_len:
+
+            # get model output
+            output = model(inp) # (1, seq_len, vocab_sz)
+            toks_generated += 1
+
+            # get the logits for the last token
+            logits = output[0,-1,:] # (vocab_sz)
+
+            # sample the most relevant token from the probability
+            if temperature > 0:
+                # if temperature is enabled, nucleus (top-p) sampling
+                probs = torch.softmax(logits / temperature, dim=-1)
+                token = sample_top_p(probs, top_p)
+            else:
+                # otherwise, fallback to greedy decoding
+                token = torch.argmax(logits, dim=-1)
+                token.unsqueeze_(0)
+
+            # check if token is EOS
+            if token == tok.eos_id:
+                break
+
+            # append the token to the input
+            inp = torch.cat([inp, token.unsqueeze(0)], dim=-1) # (1, seq_len + 1)
+
+            seq_len = inp.shape[-1]
+
+    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
+    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
+
+    print("Finished inference!")
+
+
+if __name__=='__main__':
+    fire.Fire(main)