diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 5539156b555e6eae1834836e84b2eeaaab5c0d6a..b1e00c961547fdd071d294743b8545636833f823 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -51,8 +51,25 @@ class _OptimusDL(Iterable):
 
         print(f"Done. Took {time.time() - start:.2f}s.")
 
-        self.num_batches = torch.cat(self._data,
-                                     dim=-1).shape[0] // self.bs // self.seq_len
+        # pre-calculate the number of batches in the dataset
+
+        # Note: there is a special case to be careful about. Since the
+        # prediction targets are simply the inputs shifted to the right by
+        # one token, the stream can end before the targets for the last
+        # batch are available; this happens iff `batch_len % seq_len == 0`.
+        # To avoid it, we subtract 1 from the total number of available
+        # batches.
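+        # e.g. (illustrative numbers): a stream of 1,000 tokens with bs = 10
+        # gives batch_len = 100; with seq_len = 20, 100 % 20 == 0, so we use
+        # 100 // 20 - 1 = 4 batches, because the 5th batch's shifted-right
+        # targets would need a 101st token that the stream does not have.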
+        dataset_stream_len = sum(len(sample) for sample in self._data)
+
+        batch_len = dataset_stream_len // self.bs
+        self.num_batches = batch_len // self.seq_len
+
+        if batch_len % self.seq_len == 0:
+            self.num_batches -= 1
 
     def _process_data_before_iter(self):
         data = self._data