Skip to content
Snippets Groups Projects

Add inference code

Closed Vlad-Andrei BĂDOIU (78692) requested to merge vladb/inferencev2 into main
2 files
+ 143
0
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 25
0
@@ -197,3 +197,28 @@ class Trainer():
f"\tTotal valid batches: {len(self.dl.test):10d} | "
f"Valid loss: {self.val_loss: 7.2f} | "
f"Valid perplexity: {self.val_ppl: 8.2f}")
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Autoregressively extend the token sequence ``idx``.

    Given a conditioning sequence of indices (LongTensor of shape (b, t)),
    sample ``max_new_tokens`` additional tokens one at a time, feeding each
    prediction back into the model. Callers will most likely want the model
    in ``model.eval()`` mode for this.

    Args:
        idx: LongTensor of shape (b, t) holding the conditioning tokens.
        max_new_tokens: number of tokens to append.
        temperature: scaling divisor applied to the final-position logits
            before sampling (1.0 leaves them unchanged).
        top_k: if given, restrict sampling to the k highest-probability
            tokens at each step.

    Returns:
        LongTensor of shape (b, t + max_new_tokens).
    """
    for _ in range(max_new_tokens):
        # TODO: Once we have access to context_size change the line to
        # if idx.size(1) <= self.context_size else idx[:, -self.context_size:]
        idx_cond = idx
        # Forward pass; keep only the last position's logits, rescaled by
        # the sampling temperature.
        last_logits = self(idx_cond)[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit.
            kth_vals, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
            last_logits[last_logits < kth_vals[:, [-1]]] = -float('Inf')
        # Normalize to a distribution, draw one token per batch row, append.
        probs = F.softmax(last_logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_tok), dim=1)
    return idx
Loading