From cbf807dd9df8c72bf50dd5e71f7d3ecc7f9f5d27 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Fri, 26 Jan 2024 20:38:01 +0200
Subject: [PATCH] Add tokens-per-second reporting

Output the model's generation throughput, in tokens per second, at the
end of inference.
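
Throughput is measured by counting one token per model forward pass
and dividing by the wall-clock time of the generation loop. A minimal,
self-contained sketch of the pattern (step() below is a hypothetical
stand-in for the real model call):

    import time

    def step():
        # hypothetical stand-in for one model forward pass
        time.sleep(0.01)

    toks_generated = 0
    start_time = time.time()
    for _ in range(50):    # stand-in for the generation loop
        step()             # each forward pass yields one new token
        toks_generated += 1
    elapsed = time.time() - start_time
    print(f"Tokens / second: {toks_generated / elapsed:.2f}")

time.time() matches the rest of this change; time.perf_counter() would
be the more precise clock for interval timing, though the difference
should be negligible over a multi-second generation run.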
---
 inference.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/inference.py b/inference.py
index 1606a2a..45a81ae 100644
--- a/inference.py
+++ b/inference.py
@@ -1,4 +1,5 @@
 import os
+import time
 import warnings
 
 import fire
@@ -147,12 +148,16 @@ def main(model_path: str = 'model.pth',
 
     seq_len = inp.shape[-1]
 
+    toks_generated = 0 # count of newly generated tokens
+    start_time = time.time() # wall-clock start of the generation loop
+
     with torch.inference_mode():
 
         while seq_len != max_seq_len:
 
             # get model output
             output = model(inp) # (1, seq_len, vocab_sz)
+            toks_generated += 1 # one forward pass yields one new token
 
             # get the logits for the last token
             logits = output[0,-1,:] # (vocab_sz)
@@ -176,7 +181,8 @@ def main(model_path: str = 'model.pth',
 
             seq_len = inp.shape[-1]
 
-    print("Model output: ", ' '.join(tok.decode(inp.tolist())))
+    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
+    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
 
     print("Finished inference!")
 
-- 
GitLab