Skip to content
Snippets Groups Projects
Unverified Commit cbf807dd authored by Alexandru-Mihai GHERGHESCU's avatar Alexandru-Mihai GHERGHESCU
Browse files

Add tokens per second information

Output model tokens per second at the end of inference.
parent 0ee3b108
No related branches found
No related tags found
1 merge request: !15 "Add inference code"
import os
import time
import warnings
import fire
@@ -147,12 +148,16 @@ def main(model_path: str = 'model.pth',
    seq_len = inp.shape[-1]

    toks_generated = 0
    start_time = time.time()
    with torch.inference_mode():
        while seq_len != max_seq_len:
            # get model output
            output = model(inp) # (1, seq_len, vocab_sz)
            toks_generated += 1

            # get the logits for the last token
            logits = output[0,-1,:] # (vocab_sz)
@@ -176,7 +181,8 @@ def main(model_path: str = 'model.pth',
        seq_len = inp.shape[-1]

    print(f"Model output: {' '.join(tok.decode(inp.tolist()))}")
    print(f"Tokens / second: {toks_generated / (time.time() - start_time):.2f}")
    print("Finished inference!")
...
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment