From bf8752496fda21a863c687c1d4e3d5abac53302e Mon Sep 17 00:00:00 2001
From: Vlad-Andrei Badoiu <vlad_andrei.badoiu@upb.ro>
Date: Sun, 18 Feb 2024 22:36:57 +0000
Subject: [PATCH] Adapt optimus to Distributon

This commit adapts the existing code to use the distributed library via a config option.
To achieve this we switch to using Pytorch's dataloader.
---
 optimus/dataloader.py               | 185 ----------------------------
 optimus/datasets/prepare_dataset.py |   0
 optimus/datasets/tinystories.py     |  30 +++--
 optimus/trainer.py                  |  61 ++++++---
 training.py                         |  42 +++++--
 5 files changed, 95 insertions(+), 223 deletions(-)
 delete mode 100644 optimus/dataloader.py
 create mode 100644 optimus/datasets/prepare_dataset.py

diff --git a/optimus/dataloader.py b/optimus/dataloader.py
deleted file mode 100644
index 7bcd74e..0000000
--- a/optimus/dataloader.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import time
-import random
-from typing import Tuple, Iterator, Iterable
-
-import torch
-from torch import Tensor
-from torch.utils.data import Dataset
-
-from optimus.tokenizers import SentencePieceTokenizer
-
-
-class _OptimusDLIter(Iterator):
-    def __init__(self, dl):
-        """
-        _OptimusDL iterator.
-
-        """
-        self.dl = dl
-        self.curr = 0
-
-    def __next__(self) -> Tuple[Tensor, Tensor]:
-        if self.curr > len(self.dl) - 1:
-            raise StopIteration
-
-        x = self.dl._items[:, self.curr * self.dl.seq_len     : (self.curr + 1) * self.dl.seq_len]
-        y = self.dl._items[:, self.curr * self.dl.seq_len + 1 : (self.curr + 1) * self.dl.seq_len + 1]
-
-        self.curr += 1
-
-        return x, y
-
-class _OptimusDL(Iterable):
-    def __init__(self, ds, tok, bs, seq_len, shuffle, device):
-        """
-        See 'OptimusDataLoader'.
-
-        """
-        self.ds = ds
-        self.tok = tok
-        self.bs = bs
-        self.seq_len = seq_len
-        self.shuffle = shuffle
-        self.device = device
-
-        start = time.time()
-        print("Tokenizing dataset...")
-
-        # tokenize the dataset, add BOS/EOS tokens
-        self._data = [torch.tensor(self.tok.encode(item, bos=True, eos=True),
-                                   dtype=torch.long) for item in self.ds]
-
-        print(f"Done. Took {time.time() - start:.2f}s.")
-
-        # pre-calculate the number of batches in the dataset
-
-        # Note: there's a special case we need to be careful about; since the
-        # predictions are simply the inputs shifted to the right by one value;
-        # there's a case when the dataset ends before we can get these
-        # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`;
-        # to avoid this, we have to be explicit about the available number of
-        # batches (by simply subtracting 1 from the total number of available
-        # batches)
-        dataset_stream_len = 0
-        for sample in self._data:
-            dataset_stream_len += len(sample)
-
-        batch_len = dataset_stream_len // self.bs
-        self.num_batches = batch_len // self.seq_len
-
-        if batch_len % self.seq_len == 0:
-            self.num_batches -= 1
-
-    def _process_data_before_iter(self):
-        data = self._data
-
-        # shuffle text (this keeps information intact, as documents are
-        # shuffled, not the tokens inside of them!)
-        if self.shuffle:
-            random.shuffle(data)
-
-        # stack the whole dataset into one big line of text
-        data = torch.cat(data, dim=-1)
-
-        # make streams (drop elements which don't cleanly fit)
-        batch_len = data.shape[0] // self.bs
-        data = data[:batch_len * self.bs]
-        data = data.view(self.bs, batch_len).contiguous()
-
-        # move data to device
-        self._items = data.to(self.device)
-
-    def to(self, device) -> None:
-        """
-        See 'OptimusDataLoader.to(device)'.
-
-        """
-        self.device = device
-
-    def __iter__(self) -> Iterator[_OptimusDLIter]:
-        """
-        Return an iterator over the dataloader object.
-
-        """
-        self._process_data_before_iter()
-        return _OptimusDLIter(self)
-
-    def __len__(self) -> int:
-        return self.num_batches
-
-
-class OptimusDataLoader():
-
-    def __init__(self,
-                 train_ds: Dataset,
-                 test_ds: Dataset,
-                 tok: SentencePieceTokenizer,
-                 bs: int,
-                 seq_len: int,
-                 shuffle: bool = True,
-                 device: str | torch.device = 'cpu'):
-        """
-        A data loader specialized for training the Optimus model. This is mostly
-        whatever pytorch is doing with the default DataLoader, but has a few
-        extra bits specific to NLP tasks. Essentially, it is a stripped down
-        version of fastai's DataLoaders object.
-
-        The dataloader should receive text unformatted and untokenized. This is
-        all taken care of inside the dataloader. Shuffling, tokenization and
-        batching all happens when calling iter() on the dataloader's train and
-        test attributes. This should happen before each epoch while training.
-
-        Args:
-            train_ds (Dataset): The training dataset.
-            test_ds (Dataset): The testing dataset.
-            tok (SentencePieceTokenizer): Tokenizer to use.
-            bs (int): Batch size. The number of samples per batch to load.
-            seq_len (int): Sequence length. Also referred to as context length.
-            shuffle (bool): Whether to shuffle the training data before
-                training. True means shuffle before every epoch. Defaults to
-                True.
-            device (str | torch.device): The device where to put the data (CPU,
-                GPU or other devices). Defaults to CPU.
-
-        """
-        self.train = _OptimusDL(ds=train_ds,
-                                tok=tok,
-                                bs=bs,
-                                seq_len=seq_len,
-                                shuffle=shuffle,
-                                device=device)
-
-        self.test = _OptimusDL(ds=test_ds,
-                               tok=tok,
-                               bs=bs,
-                               seq_len=seq_len,
-                               shuffle=False,
-                               device=device)
-
-    def to(self, device: str | torch.device) -> 'OptimusDataLoader':
-        """
-        Move data to device. This creates a copy on the specified device (if the
-        device is different from the one the data currently resides on).
-
-        Args:
-            device(str | torch.device): Device to move data to.
-
-        Returns:
-            OptimusDataLoader: The dataloader moved to the new device.
-
-        """
-        self.train.to(device)
-        self.test.to(device)
-        return self
-
-    def cpu(self) -> 'OptimusDataLoader':
-        return self.to(device=torch.device('cpu'))
-
-    def cuda(self) -> 'OptimusDataLoader':
-        return self.to(device=torch.device('cuda'))
-
-    def __iter__(self):
-        raise TypeError("'OptimusDataLoader' is not iterable. Please use "
-                        "iter(OptimusDataLoader.train) or "
-                        "iter(OptimusDataLoader.test) to get an iterable for "
-                        "the train or test dataloaders, respectively")
diff --git a/optimus/datasets/prepare_dataset.py b/optimus/datasets/prepare_dataset.py
new file mode 100644
index 0000000..e69de29
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index ba89567..bfc36a8 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -1,5 +1,7 @@
 import os
+import time
 
+import torch
 from torch.utils.data import Dataset, DataLoader
 
 from .dataset_utils import *
@@ -27,7 +29,7 @@ _EXTRACTED_FILES = {
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str | None = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train', tokenizer = None):
         """
         TinyStories dataset.
 
@@ -74,7 +76,7 @@ class TinyStoriesDataset(Dataset):
         else:
             print(f"Found dataset at '{path}'. Using this for '{split}' split...")
 
-        self.stories = []
+        stories = []
 
         # open the dataset file and read the stories from it
         with open(path, 'r') as file:
@@ -86,26 +88,38 @@ class TinyStoriesDataset(Dataset):
                 if line == '<|endoftext|>\n':
                     # found the beginning of a story; save the previous one and
                     # begin building a new story
-                    self.stories.append(' '.join(story))
+                    stories.append(' '.join(story))
                     story = []
 
                 else:
                     # append the line to the story
                     story.append(line)
 
-        train_test_split = int(0.95 * len(self.stories))
+
+        start = time.time()
+        print("Tokenizing dataset...")
+
+        self.tok = tokenizer
+        # tokenize the dataset, add BOS/EOS tokens
+        self._data = [torch.tensor(tokenizer.encode(item, bos=True, eos=True),
+                                   dtype=torch.long) for item in stories]
+
+
+        print(f"Done. Took {time.time() - start:.2f}s.")
+
+        train_test_split = int(0.95 * len(self._data))
 
         if split == 'train':
-            self.stories = self.stories[:train_test_split]
+            self._data = self._data[:train_test_split]
         elif split == 'test':
-            self.stories = self.stories[train_test_split:]
+            self._data = self._data[train_test_split:]
 
     def __len__(self) -> int:
         """
         Return the length of the dataset, which is the total number of
         TinyStories stories contained in it.
         """
-        return len(self.stories)
+        return len(self._data)
 
     def __getitem__(self, idx: int) -> str:
         """
@@ -115,7 +129,7 @@ class TinyStoriesDataset(Dataset):
             idx (int): The index of the story in the dataset.
 
         """
-        return self.stories[idx]
+        return self._data[idx]
 
 
 if __name__=='__main__':
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 0c0c381..5a8434a 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -7,14 +7,18 @@ import torch.nn as nn
 import torch.optim as optim
 from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
-from optimus.dataloader import OptimusDataLoader
+from torch.utils.data import DataLoader
+
 from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
+from .distributon import Distributon
+
 class Trainer():
 
     def __init__(self,
-                 dl: OptimusDataLoader,
+                 train_loader: DataLoader,
+                 valid_loader: DataLoader,
                  model: nn.Module,
                  criterion: Callable,
                  optimizer: optim.Optimizer,
@@ -23,6 +27,7 @@ class Trainer():
                  grad_clip_norm: float,
                  model_save_path: str,
                  use_fp16: bool,
+                 distributon:Distributon = None,
                  progress_bar: bool = True):
         """
         Trainer implementation for Optimus models.
@@ -51,12 +56,20 @@ class Trainer():
                 is a console or a file.
 
         """
-        self.dl = dl
+        self.train_loader = train_loader
+        self.valid_loader = valid_loader
         self.model = model
         self.criterion = criterion
         self.optimizer = optimizer
         self.lr = lr
 
+
+        self.distributon = distributon
+
+        if distributon is not None:
+            self.model = self.distributon.setup_model(model)
+            self.optimizer = self.distributon.setup_optimizer(optimizer)
+
         assert type(grad_acc_steps) is int and grad_acc_steps > 0
         self.grad_acc_steps = grad_acc_steps
 
@@ -82,7 +95,7 @@ class Trainer():
             optimizer=self.optimizer,
             max_lr=self.lr,
             epochs=n_epochs,
-            steps_per_epoch=len(self.dl.train) // self.grad_acc_steps)
+            steps_per_epoch=len(self.train_loader) // self.grad_acc_steps)
 
         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16)
@@ -117,16 +130,16 @@ class Trainer():
         # compute average train loss, train perplexity and ms/batch every ~200
         # batches, or every 10% of training dataset (whichever is smaller),
         # rounded to gradient accumulation steps
-        est_interval = int(max(min(200, 0.1 * len(self.dl.train)) // self.grad_acc_steps, 1) * self.grad_acc_steps)
+        est_interval = int(max(min(200, 0.1 * len(self.train_loader)) // self.grad_acc_steps, 1) * self.grad_acc_steps)
 
         # progress bar for batches
-        pb = progress_bar(range(len(self.dl.train)), parent=self.mb)
+        pb = progress_bar(range(len(self.train_loader)), parent=self.mb)
 
         self.ms_per_batch = 0.
         total_loss = 0.
         start_time = time.time()
 
-        for i, (x, y) in enumerate(self.dl.train):
+        for i, (x, y) in enumerate(self.train_loader):
 
             if self.progress_bar is True:
                 pb.update(i)
@@ -135,12 +148,15 @@ class Trainer():
             with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
                                          enabled=self.use_fp16):
                 output = self.model(x)
-                loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
-                                      y.reshape(-1))
+                loss = self.criterion(output.view(-1, len(self.train_loader.dataset.tok)),
+                                  y.reshape(-1))
 
                 loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
 
-            self.scaler.scale(loss).backward()
+            if self.distributon:
+                self.distributon.backward(loss)
+            else:
+                self.scaler.scale(loss).backward()
 
             total_loss += loss.item()
 
@@ -150,13 +166,18 @@ class Trainer():
                 # when the gradient accumulation steps are more than 1, and the
                 # number of batches doesn't cleanly divide by grad_acc_steps
 
-                # gradient clipping
-                self.scaler.unscale_(self.optimizer)
+                if not self.distributon:
+                    # gradient clipping
+                    self.scaler.unscale_(self.optimizer)
+
                 nn.utils.clip_grad_norm_(self.model.parameters(),
                                          max_norm=self.grad_clip_norm)
 
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
+                # TODO: Fix this for distributed
+                if not self.distributon:
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+
                 self.optimizer.zero_grad()
 
                 self.scheduler.step()
@@ -183,10 +204,10 @@ class Trainer():
         total_loss = 0.
 
         # progress bar for batches
-        pb = progress_bar(range(len(self.dl.test)), parent=self.mb)
+        pb = progress_bar(range(len(self.valid_loader)), parent=self.mb)
 
         with torch.no_grad():
-            for i, (x, y) in enumerate(self.dl.test):
+            for i, (x, y) in enumerate(self.valid_loader):
 
                 if self.progress_bar is True:
                     pb.update(i)
@@ -194,14 +215,14 @@ class Trainer():
                 with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
                                              enabled=self.use_fp16):
                     output = self.model(x)
-                    loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
+                    loss = self.criterion(output.view(-1, len(self.valid_loader.dataset.tok)),
                                           y.reshape(-1))
 
                 total_loss += loss.item()
 
                 self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
 
-            self.val_loss = total_loss / (len(self.dl.test) - 1)
+            self.val_loss = total_loss / (len(self.valid_loader) - 1)
             self.val_ppl = math.exp(self.val_loss)
 
             pb.on_iter_end()
@@ -213,9 +234,9 @@ class Trainer():
                 f"* End of epoch {self.epoch:3d}:\n"
                 f"\tTotal time: {epoch_time:9s} | "
                 f"Est. ms/batch: {self.ms_per_batch:.2f}\n"
-                f"\tTotal train batches: {len(self.dl.train):10d} | "
+                f"\tTotal train batches: {len(self.train_loader):10d} | "
                 f"Train loss: {self.train_loss: 7.2f} | "
                 f"Train perplexity: {self.train_ppl: 8.2f}\n"
-                f"\tTotal valid batches: {len(self.dl.test):10d} | "
+                f"\tTotal valid batches: {len(self.valid_loader):10d} | "
                 f"Valid loss: {self.val_loss: 7.2f} | "
                 f"Valid perplexity: {self.val_ppl: 8.2f}")
diff --git a/training.py b/training.py
index 64cdf8f..f0c5998 100644
--- a/training.py
+++ b/training.py
@@ -2,12 +2,13 @@ import fire
 import torch
 from torch import nn
 
-from optimus.datasets import WikiText103Dataset
+from optimus.datasets import TinyStoriesDataset
 from optimus.tokenizers import SentencePieceTokenizer
-from optimus.dataloader import OptimusDataLoader
+from optimus.distributon.dataloader import build_dataloader
 from optimus.models import OptimusTransformer
 from optimus.trainer import Trainer
 
+from optimus.distributon import Distributon
 
 def main(batch_size: int = 8,
          grad_acc_steps: int = 1,
@@ -21,7 +22,8 @@ def main(batch_size: int = 8,
          n_layers: int = 6,
          n_heads: int = 8,
          dropout: float = 0.0,
-         use_fp16: bool = True):
+         use_fp16: bool = True,
+         distributed: bool = False):
     """
     Run the main training loop for the model.
 
@@ -61,21 +63,37 @@ def main(batch_size: int = 8,
         f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
         f"Please see '--help' if you want to change these settings")
 
+    
+    # Launch the distributed proccesses
+    if distributed:
+        distributon = Distributon([f"cuda:{i}" for i in range(torch.cuda.device_count())])
+        distributon.launch()
+        device = distributon._strategy.root_device
+    else:
+        distributon = None
+
     # load tokenizer
     tok = SentencePieceTokenizer(model_path=tokenizer_path)
 
     # load dataset splits
-    train_ds = WikiText103Dataset(split='train')
-    test_ds = WikiText103Dataset(split='test')
+    train_ds = TinyStoriesDataset(split='train', tokenizer=tok)
+    test_ds = TinyStoriesDataset(split='test', tokenizer=tok)
 
     print(f"Number of examples in training set: {len(train_ds)}")
     print(f"Number of examples in testing set: {len(test_ds)}")
 
-    # create dataloader object and move to device
-    dl = OptimusDataLoader(train_ds, test_ds, tok,
+    # create the dataloaders
+    train_loader = build_dataloader(train_ds,
                            bs=batch_size,
                            seq_len=seq_len,
-                           device='cuda')
+                           device=device,
+                           distributed=distributed)
+
+    valid_loader = build_dataloader(train_ds,
+                        bs=batch_size,
+                        seq_len=seq_len,
+                        device=device,
+                        distributed=distributed)
 
     # create model and move to device
     model = OptimusTransformer(len(tok),
@@ -84,7 +102,9 @@ def main(batch_size: int = 8,
                                n_heads=n_heads,
                                p_drop=dropout,
                                weight_tying=False)
-    model = model.to('cuda')
+
+    if not distributed:
+        model = model.to('cuda')
 
     _total_params = sum(p.numel() for p in model.parameters())
     print(f"Number of model parameters: {_total_params}")
@@ -101,7 +121,8 @@ def main(batch_size: int = 8,
     print("Starting training...")
 
     # create trainer and start fitting
-    trainer = Trainer(dl=dl,
+    trainer = Trainer(train_loader=train_loader,
+                      valid_loader=valid_loader,
                       model=model,
                       criterion=criterion,
                       optimizer=optimizer,
@@ -110,6 +131,7 @@ def main(batch_size: int = 8,
                       grad_clip_norm=grad_clip_norm,
                       model_save_path=checkpoints_path,
                       use_fp16=use_fp16,
+                      distributon=distributon,
                       progress_bar=True)
     trainer.fit(epochs)
 
-- 
GitLab