
Fix a number of issues with the infrastructure, no major rework

Merged Alexandru-Mihai GHERGHESCU requested to merge fix/general_small_fixes into main
1 unresolved thread
1 file   +18 −2
  • There was a corner case where the shape of the predictions y of the
    dataset was incorrect, because the number of batches was
    miscalculated.

    This happened when `batch_len` was exactly divisible by `seq_len`:
    the predictions, which are simply the text shifted once to the right,
    would not have that extra column at the end.

    Fix the above issue by decrementing the number of available batches
    by 1 when `batch_len` is exactly divisible by `seq_len`.
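A minimal standalone sketch of the corrected batch-count logic described above (the function name and the example values for `bs` and `seq_len` are hypothetical, chosen only to illustrate the corner case):

```python
def num_batches(dataset_stream_len: int, bs: int, seq_len: int) -> int:
    """Count the batches that can serve both inputs and predictions.

    Predictions are the inputs shifted right by one token, so when
    batch_len is an exact multiple of seq_len the last batch has no
    prediction column and must be dropped.
    """
    batch_len = dataset_stream_len // bs
    n = batch_len // seq_len
    if batch_len % seq_len == 0:
        # the shifted-right predictions of the final batch would run
        # past the end of the data stream, so it cannot be served
        n -= 1
    return n

# batch_len = 400 // 4 = 100, exactly divisible by seq_len = 10:
# the naive count of 10 is reduced to 9
print(num_batches(400, 4, 10))  # -> 9

# batch_len = 404 // 4 = 101, not divisible: the leftover token
# already provides the final prediction column, so all 10 batches fit
print(num_batches(404, 4, 10))  # -> 10
```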
@@ -51,8 +51,24 @@ class _OptimusDL(Iterable):
             print(f"Done. Took {time.time() - start:.2f}s.")

-        self.num_batches = torch.cat(self._data,
-                                     dim=-1).shape[0] // self.bs // self.seq_len
+        # pre-calculate the number of batches in the dataset
+        # Note: there's a special case we need to be careful about; since the
+        # predictions are simply the inputs shifted to the right by one value,
+        # there's a case when the dataset ends before we can get these
+        # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`;
+        # to avoid this, we have to be explicit about the available number of
+        # batches (by simply subtracting 1 from the total number of available
+        # batches)
+        dataset_stream_len = 0
+        for sample in self._data:
+            dataset_stream_len += len(sample)
+
+        batch_len = dataset_stream_len // self.bs
+
+        self.num_batches = batch_len // self.seq_len
+        if batch_len % self.seq_len == 0:
+            self.num_batches -= 1

     def _process_data_before_iter(self):
         data = self._data