From 3852611c11ac77ea58e8c3a90ca30c3ecd0a0f98 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Fri, 26 Jan 2024 22:58:06 +0200
Subject: [PATCH] Fix estimation interval

Fix a bug where the estimation interval could end up as 0. This only
happened for (very) small datasets with a gradient accumulation steps
value other than 1: the interval was clamped to at least 1 before being
floored to a multiple of the accumulation steps, so any interval
smaller than the accumulation steps was rounded down to 0. The interval
is now clamped after the division, so it is always at least one full
accumulation cycle.
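
A minimal sketch of the arithmetic (the names below are illustrative,
not taken from the trainer), assuming a 5-batch training set and 4
gradient accumulation steps:

    dataset_len = 5
    grad_acc_steps = 4

    # old formula: clamp to at least 1, then floor to a multiple of
    # grad_acc_steps; small intervals collapse to 0
    int(max(min(200, 0.1 * dataset_len), 1)) // grad_acc_steps * grad_acc_steps
    # -> 1 // 4 * 4 == 0, so the estimation never runs

    # new formula: divide first, then clamp, so the result is at least
    # one full accumulation cycle
    int(max(min(200, 0.1 * dataset_len) // grad_acc_steps, 1) * grad_acc_steps)
    # -> int(max(0.0, 1) * 4) == 4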
---
 optimus/trainer.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/optimus/trainer.py b/optimus/trainer.py
index fd547ba..5854714 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -103,17 +103,18 @@ class Trainer():
     def _do_epoch_train(self):
         self.model.train() # put model in training mode
 
-        # compute average train loss, train ppl and ms/batch every ~200 batches,
-        # or every 10% of training dataset (whichever is smaller), rounded to
-        # gradient accumulation steps
-        self.ms_per_batch = 0.
-        total_loss = 0.
-        est_interval = int(max(min(200, 0.1 * len(self.dl.train)), 1)) // self.grad_acc_steps * self.grad_acc_steps
-        start_time = time.time()
+        # compute average train loss, train perplexity and ms/batch every ~200
+        # batches, or every 10% of training dataset (whichever is smaller),
+        # rounded to gradient accumulation steps
+        est_interval = int(max(min(200, 0.1 * len(self.dl.train)) // self.grad_acc_steps, 1) * self.grad_acc_steps)
 
         # progress bar for batches
         pb = progress_bar(range(len(self.dl.train)), parent=self.mb)
 
+        self.ms_per_batch = 0.
+        total_loss = 0.
+        start_time = time.time()
+
         for i, (x, y) in enumerate(self.dl.train):
 
             if self.progress_bar is True:
-- 
GitLab