From 7aa99b4aef7c66710bb827e77478c55871f7434f Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 22 Jan 2024 21:51:01 +0200
Subject: [PATCH] Fix bad calculation for number of batches

There was a corner case in which the shape of the predictions `y` of the
dataset was incorrect, because the number of batches was miscalculated.

This happened when `batch_len` was exactly divisible by `seq_len`: the
predictions are simply the text shifted once to the right, so the last
batch of predictions would be missing its final column, because the
data ends before that shifted value is available.

Fix the above issue by decrementing the number of available batches by
1 when `batch_len` is exactly divisible by `seq_len`.
---
 optimus/dataloader.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 5539156..b1e00c9 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -51,8 +51,24 @@ class _OptimusDL(Iterable):
 
         print(f"Done. Took {time.time() - start:.2f}s.")
 
-        self.num_batches = torch.cat(self._data,
-                                     dim=-1).shape[0] // self.bs // self.seq_len
+        # pre-calculate the number of batches in the dataset
+
+        # Note: there's a special case we need to be careful about; since the
+        # predictions are simply the inputs shifted to the right by one value;
+        # there's a case when the dataset ends before we can get these
+        # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`;
+        # to avoid this, we have to be explicit about the available number of
+        # batches (by simply subtracting 1 from the total number of available
+        # batches)
+        dataset_stream_len = 0
+        for sample in self._data:
+            dataset_stream_len += len(sample)
+
+        batch_len = dataset_stream_len // self.bs
+        self.num_batches = batch_len // self.seq_len
+
+        if batch_len % self.seq_len == 0:
+            self.num_batches -= 1
 
     def _process_data_before_iter(self):
         data = self._data
-- 
GitLab