diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 5539156b555e6eae1834836e84b2eeaaab5c0d6a..b1e00c961547fdd071d294743b8545636833f823 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -51,8 +51,24 @@ class _OptimusDL(Iterable):
 
         print(f"Done. Took {time.time() - start:.2f}s.")
 
-        self.num_batches = torch.cat(self._data,
-                                     dim=-1).shape[0] // self.bs // self.seq_len
+        # pre-calculate the number of batches in the dataset
+
+        # Note: there's a special case we need to be careful about; since the
+        # predictions are simply the inputs shifted to the right by one value,
+        # there's a case when the dataset ends before we can get these
+        # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`;
+        # to avoid this, we have to be explicit about the available number of
+        # batches (by simply subtracting 1 from the total number of available
+        # batches)
+        dataset_stream_len = 0
+        for sample in self._data:
+            dataset_stream_len += len(sample)
+
+        batch_len = dataset_stream_len // self.bs
+        self.num_batches = batch_len // self.seq_len
+
+        if batch_len % self.seq_len == 0:
+            self.num_batches -= 1
 
     def _process_data_before_iter(self):
         data = self._data
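
To make the edge case above concrete, here is a standalone sketch with hypothetical sizes, not part of the patch. It assumes the usual convention implied by the diff: the concatenated token stream is split into `bs` rows of `batch_len` tokens each, and targets are the inputs shifted by one position, which is why the final `seq_len` chunk is unusable exactly when `batch_len % seq_len == 0`.

import torch


def num_batches(dataset_stream_len: int, bs: int, seq_len: int) -> int:
    # Same arithmetic as in the patch: each of the `bs` rows holds
    # `batch_len` consecutive tokens, and each batch consumes `seq_len`
    # tokens per row.
    batch_len = dataset_stream_len // bs
    n = batch_len // seq_len
    # Targets are the inputs shifted by one token, so the last chunk needs
    # one extra token past position `batch_len - 1`; when `batch_len`
    # divides evenly by `seq_len`, that token does not exist.
    if batch_len % seq_len == 0:
        n -= 1
    return n


# Hypothetical numbers: a 120-token stream with bs=4 gives batch_len=30
# tokens per row; with seq_len=10 the third chunk's targets would need
# token index 30, which is out of range, so only 2 batches are usable.
stream = torch.arange(120)
bs, seq_len = 4, 10
rows = stream[: (len(stream) // bs) * bs].view(bs, -1)     # shape (4, 30)

for i in range(num_batches(len(stream), bs, seq_len)):      # yields i = 0, 1
    x = rows[:, i * seq_len:(i + 1) * seq_len]               # inputs
    y = rows[:, i * seq_len + 1:(i + 1) * seq_len + 1]       # shifted targets
    assert x.shape == y.shape == (bs, seq_len)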