Skip to content
Snippets Groups Projects
Unverified Commit 3852611c authored by Alexandru-Mihai GHERGHESCU's avatar Alexandru-Mihai GHERGHESCU
Browse files

Fix estimation interval

Fix a bug where the estimation interval would be 0. This only happened
for (very) small datasets, with gradient accumulation steps different
than 1.
parent 5bc6558f
No related branches found
No related tags found
1 merge request!16Fix estimation interval
...@@ -103,17 +103,18 @@ class Trainer(): ...@@ -103,17 +103,18 @@ class Trainer():
def _do_epoch_train(self): def _do_epoch_train(self):
self.model.train() # put model in training mode self.model.train() # put model in training mode
# compute average train loss, train ppl and ms/batch every ~200 batches, # compute average train loss, train perplexity and ms/batch every ~200
# or every 10% of training dataset (whichever is smaller), rounded to # batches, or every 10% of training dataset (whichever is smaller),
# gradient accumulation steps # rounded to gradient accumulation steps
self.ms_per_batch = 0. est_interval = int(max(min(200, 0.1 * len(self.dl.train)) // self.grad_acc_steps, 1) * self.grad_acc_steps)
total_loss = 0.
est_interval = int(max(min(200, 0.1 * len(self.dl.train)), 1)) // self.grad_acc_steps * self.grad_acc_steps
start_time = time.time()
# progress bar for batches # progress bar for batches
pb = progress_bar(range(len(self.dl.train)), parent=self.mb) pb = progress_bar(range(len(self.dl.train)), parent=self.mb)
self.ms_per_batch = 0.
total_loss = 0.
start_time = time.time()
for i, (x, y) in enumerate(self.dl.train): for i, (x, y) in enumerate(self.dl.train):
if self.progress_bar is True: if self.progress_bar is True:
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment