From 7aa99b4aef7c66710bb827e77478c55871f7434f Mon Sep 17 00:00:00 2001 From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro> Date: Mon, 22 Jan 2024 21:51:01 +0200 Subject: [PATCH] Fix bad calculation for number of batches There was a corner case when the shape of the predictions y of the dataset would not be correct, due to the fact that the number of batches was miscalculated. This happened when `batch_len` was exactly divisible by `seq_len`, since the predictions, which are simply the text shifted once to the right, would not have that extra column at the end. Fix the above issue by decrementing the number of available batches by 1 when `batch_len` is exactly divisible by `seq_len`. --- optimus/dataloader.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/optimus/dataloader.py b/optimus/dataloader.py index 5539156..b1e00c9 100644 --- a/optimus/dataloader.py +++ b/optimus/dataloader.py @@ -51,8 +51,24 @@ class _OptimusDL(Iterable): print(f"Done. Took {time.time() - start:.2f}s.") - self.num_batches = torch.cat(self._data, dim=-1).shape[0] // self.bs // self.seq_len + # pre-calculate the number of batches in the dataset + + # Note: there's a special case we need to be careful about; since the + # predictions are simply the inputs shifted to the right by one value; + # there's a case when the dataset ends before we can get these + # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`; + # to avoid this, we have to be explicit about the available number of + # batches (by simply subtracting 1 from the total number of available + # batches) + dataset_stream_len = 0 + for sample in self._data: + dataset_stream_len += len(sample) + + batch_len = dataset_stream_len // self.bs + self.num_batches = batch_len // self.seq_len + + if batch_len % self.seq_len == 0: + self.num_batches -= 1 def _process_data_before_iter(self): data = self._data -- GitLab