diff --git a/inference.py b/inference.py
index 1606a2a93609896ced9e4d870e7c42b7a3c19f2f..45a81ae2eb82c916f91c62d9d149d33cdcb2cabb 100644
--- a/inference.py
+++ b/inference.py
@@ -1,4 +1,5 @@
 import os
+import time
 import warnings
 
 import fire
@@ -147,12 +148,16 @@ def main(model_path: str = 'model.pth',
 
     seq_len = inp.shape[-1]
 
+    toks_generated = 0
+    start_time = time.time()
+
     with torch.inference_mode():
 
         while seq_len != max_seq_len:
 
             # get model output
             output = model(inp) # (1, seq_len, vocab_sz)
+            toks_generated += 1
 
             # get the logits for the last token
             logits = output[0,-1,:] # (vocab_sz)
@@ -176,7 +181,8 @@ def main(model_path: str = 'model.pth',
 
             seq_len = inp.shape[-1]
 
-    print("Model output: ", ' '.join(tok.decode(inp.tolist())))
+    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
+    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
 
     print("Finished inference!")