From cbf807dd9df8c72bf50dd5e71f7d3ecc7f9f5d27 Mon Sep 17 00:00:00 2001 From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro> Date: Fri, 26 Jan 2024 20:38:01 +0200 Subject: [PATCH] Add tokens per second information Output model tokens per second at the end of inference. --- inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/inference.py b/inference.py index 1606a2a..45a81ae 100644 --- a/inference.py +++ b/inference.py @@ -1,4 +1,5 @@ import os +import time import warnings import fire @@ -147,12 +148,16 @@ def main(model_path: str = 'model.pth', seq_len = inp.shape[-1] + toks_generated = 0 + start_time = time.time() + with torch.inference_mode(): while seq_len != max_seq_len: # get model output output = model(inp) # (1, seq_len, vocab_sz) + toks_generated += 1 # get the logits for the last token logits = output[0,-1,:] # (vocab_sz) @@ -176,7 +181,8 @@ def main(model_path: str = 'model.pth', seq_len = inp.shape[-1] - print("Model output: ", ' '.join(tok.decode(inp.tolist()))) + print(f"Model output: {' '.join(tok.decode(inp.tolist()))}") + print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}") print("Finished inference!") -- GitLab