diff --git a/optimus/dataloader.py b/optimus/dataloader.py
deleted file mode 100644
index 7bcd74ee52ce2d5f5e35215db751ca6f899a1709..0000000000000000000000000000000000000000
--- a/optimus/dataloader.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import time
-import random
-from typing import Tuple, Iterator, Iterable
-
-import torch
-from torch import Tensor
-from torch.utils.data import Dataset
-
-from optimus.tokenizers import SentencePieceTokenizer
-
-
-class _OptimusDLIter(Iterator):
-    def __init__(self, dl):
-        """
-        _OptimusDL iterator.
-
-        """
-        self.dl = dl
-        self.curr = 0
-
-    def __next__(self) -> Tuple[Tensor, Tensor]:
-        if self.curr > len(self.dl) - 1:
-            raise StopIteration
-
-        x = self.dl._items[:, self.curr * self.dl.seq_len : (self.curr + 1) * self.dl.seq_len]
-        y = self.dl._items[:, self.curr * self.dl.seq_len + 1 : (self.curr + 1) * self.dl.seq_len + 1]
-
-        self.curr += 1
-
-        return x, y
-
-class _OptimusDL(Iterable):
-    def __init__(self, ds, tok, bs, seq_len, shuffle, device):
-        """
-        See 'OptimusDataLoader'.
-
-        """
-        self.ds = ds
-        self.tok = tok
-        self.bs = bs
-        self.seq_len = seq_len
-        self.shuffle = shuffle
-        self.device = device
-
-        start = time.time()
-        print("Tokenizing dataset...")
-
-        # tokenize the dataset, add BOS/EOS tokens
-        self._data = [torch.tensor(self.tok.encode(item, bos=True, eos=True),
-                                   dtype=torch.long) for item in self.ds]
-
-        print(f"Done. Took {time.time() - start:.2f}s.")
-
-        # pre-calculate the number of batches in the dataset
-
-        # Note: there's a special case we need to be careful about; since the
-        # predictions are simply the inputs shifted to the right by one value;
-        # there's a case when the dataset ends before we can get these
-        # shifted-right predictions; this occurs iff `batch_len % seq_len == 0`;
-        # to avoid this, we have to be explicit about the available number of
-        # batches (by simply subtracting 1 from the total number of available
-        # batches)
-        dataset_stream_len = 0
-        for sample in self._data:
-            dataset_stream_len += len(sample)
-
-        batch_len = dataset_stream_len // self.bs
-        self.num_batches = batch_len // self.seq_len
-
-        if batch_len % self.seq_len == 0:
-            self.num_batches -= 1
-
-    def _process_data_before_iter(self):
-        data = self._data
-
-        # shuffle text (this keeps information intact, as documents are
-        # shuffled, not the tokens inside of them!)
-        if self.shuffle:
-            random.shuffle(data)
-
-        # stack the whole dataset into one big line of text
-        data = torch.cat(data, dim=-1)
-
-        # make streams (drop elements which don't cleanly fit)
-        batch_len = data.shape[0] // self.bs
-        data = data[:batch_len * self.bs]
-        data = data.view(self.bs, batch_len).contiguous()
-
-        # move data to device
-        self._items = data.to(self.device)
-
-    def to(self, device) -> None:
-        """
-        See 'OptimusDataLoader.to(device)'.
-
-        """
-        self.device = device
-
-    def __iter__(self) -> Iterator[_OptimusDLIter]:
-        """
-        Return an iterator over the dataloader object.
-
-        """
-        self._process_data_before_iter()
-        return _OptimusDLIter(self)
-
-    def __len__(self) -> int:
-        return self.num_batches
-
-
-class OptimusDataLoader():
-
-    def __init__(self,
-                 train_ds: Dataset,
-                 test_ds: Dataset,
-                 tok: SentencePieceTokenizer,
-                 bs: int,
-                 seq_len: int,
-                 shuffle: bool = True,
-                 device: str | torch.device = 'cpu'):
-        """
-        A data loader specialized for training the Optimus model. This is mostly
-        whatever pytorch is doing with the default DataLoader, but has a few
-        extra bits specific to NLP tasks. Essentially, it is a stripped down
-        version of fastai's DataLoaders object.
-
-        The dataloader should receive text unformatted and untokenized. This is
-        all taken care of inside the dataloader. Shuffling, tokenization and
-        batching all happens when calling iter() on the dataloader's train and
-        test attributes. This should happen before each epoch while training.
-
-        Args:
-            train_ds (Dataset): The training dataset.
-            test_ds (Dataset): The testing dataset.
-            tok (SentencePieceTokenizer): Tokenizer to use.
-            bs (int): Batch size. The number of samples per batch to load.
-            seq_len (int): Sequence length. Also referred to as context length.
-            shuffle (bool): Whether to shuffle the training data before
-                training. True means shuffle before every epoch. Defaults to
-                True.
-            device (str | torch.device): The device where to put the data (CPU,
-                GPU or other devices). Defaults to CPU.
-
-        """
-        self.train = _OptimusDL(ds=train_ds,
-                                tok=tok,
-                                bs=bs,
-                                seq_len=seq_len,
-                                shuffle=shuffle,
-                                device=device)
-
-        self.test = _OptimusDL(ds=test_ds,
-                               tok=tok,
-                               bs=bs,
-                               seq_len=seq_len,
-                               shuffle=False,
-                               device=device)
-
-    def to(self, device: str | torch.device) -> 'OptimusDataLoader':
-        """
-        Move data to device. This creates a copy on the specified device (if the
-        device is different from the one the data currently resides on).
-
-        Args:
-            device(str | torch.device): Device to move data to.
-
-        Returns:
-            OptimusDataLoader: The dataloader moved to the new device.
-
-        """
-        self.train.to(device)
-        self.test.to(device)
-        return self
-
-    def cpu(self) -> 'OptimusDataLoader':
-        return self.to(device=torch.device('cpu'))
-
-    def cuda(self) -> 'OptimusDataLoader':
-        return self.to(device=torch.device('cuda'))
-
-    def __iter__(self):
-        raise TypeError("'OptimusDataLoader' is not iterable. Please use "
-                        "iter(OptimusDataLoader.train) or "
-                        "iter(OptimusDataLoader.test) to get an iterable for "
-                        "the train or test dataloaders, respectively")
diff --git a/optimus/datasets/prepare_dataset.py b/optimus/datasets/prepare_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index ba895677b3f3a5418f3730615346efbeb7b52198..bfc36a8000ea707779f64ea984f1e80cffbeca82 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -1,5 +1,7 @@
 import os
+import time
 
+import torch
 from torch.utils.data import Dataset, DataLoader
 
 from .dataset_utils import *
@@ -27,7 +29,7 @@
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str | None = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train', tokenizer = None):
         """
         TinyStories dataset.
 
@@ -74,7 +76,7 @@
         else:
            print(f"Found dataset at '{path}'. Using this for '{split}' split...")
 
-        self.stories = []
+        stories = []
 
        # open the dataset file and read the stories from it
        with open(path, 'r') as file:
@@ -86,26 +88,38 @@
            if line == '<|endoftext|>\n':
                # found the beginning of a story; save the previous one and
                # begin building a new story
-                self.stories.append(' '.join(story))
+                stories.append(' '.join(story))
                story = []
            else:
                # append the line to the story
                story.append(line)
 
-        train_test_split = int(0.95 * len(self.stories))
+
+        start = time.time()
+        print("Tokenizing dataset...")
+
+        self.tok = tokenizer
+        # tokenize the dataset, add BOS/EOS tokens
+        self._data = [torch.tensor(tokenizer.encode(item, bos=True, eos=True),
+                                   dtype=torch.long) for item in stories]
+
+
+        print(f"Done. Took {time.time() - start:.2f}s.")
+
+        train_test_split = int(0.95 * len(self._data))
 
        if split == 'train':
-            self.stories = self.stories[:train_test_split]
+            self._data = self._data[:train_test_split]
        elif split == 'test':
-            self.stories = self.stories[train_test_split:]
+            self._data = self._data[train_test_split:]
 
     def __len__(self) -> int:
        """
        Return the length of the dataset, which is the total number of
        TinyStories stories contained in it.
 
        """
-        return len(self.stories)
+        return len(self._data)
 
     def __getitem__(self, idx: int) -> str:
        """
@@ -115,7 +129,7 @@
            idx (int): The index of the story in the dataset.
 
        """
-        return self.stories[idx]
+        return self._data[idx]
 
 
 if __name__=='__main__':
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 0c0c38119ff11dbc74a14e58d7e79ec1a859b9d6..5a8434a3d4cad2d3c8670951c17b0677cd864f34 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -7,14 +7,18 @@
 import torch.nn as nn
 import torch.optim as optim
 
 from fastprogress.fastprogress import master_bar, progress_bar, format_time
-from optimus.dataloader import OptimusDataLoader
+from torch.utils.data import DataLoader
+
 from fastprogress.fastprogress import master_bar, progress_bar, format_time
+from .distributon import Distributon
+
 
 class Trainer():
     def __init__(self,
-                 dl: OptimusDataLoader,
+                 train_loader: DataLoader,
+                 valid_loader: DataLoader,
                  model: nn.Module,
                  criterion: Callable,
                  optimizer: optim.Optimizer,
@@ -23,6 +27,7 @@
                  grad_clip_norm: float,
                  model_save_path: str,
                  use_fp16: bool,
+                 distributon: Distributon = None,
                  progress_bar: bool = True):
        """
        Trainer implementation for Optimus models.
@@ -51,12 +56,20 @@
        is a console or a file.
""" - self.dl = dl + self.train_loader = train_loader + self.valid_loader = valid_loader self.model = model self.criterion = criterion self.optimizer = optimizer self.lr = lr + + self.distributon = distributon + + if distributon is not None: + self.model = self.distributon.setup_model(model) + self.optimizer = self.distributon.setup_optimizer(optimizer) + assert type(grad_acc_steps) is int and grad_acc_steps > 0 self.grad_acc_steps = grad_acc_steps @@ -82,7 +95,7 @@ class Trainer(): optimizer=self.optimizer, max_lr=self.lr, epochs=n_epochs, - steps_per_epoch=len(self.dl.train) // self.grad_acc_steps) + steps_per_epoch=len(self.train_loader) // self.grad_acc_steps) # scaler used for mixed precision fp16 training on GPU self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16) @@ -117,16 +130,16 @@ class Trainer(): # compute average train loss, train perplexity and ms/batch every ~200 # batches, or every 10% of training dataset (whichever is smaller), # rounded to gradient accumulation steps - est_interval = int(max(min(200, 0.1 * len(self.dl.train)) // self.grad_acc_steps, 1) * self.grad_acc_steps) + est_interval = int(max(min(200, 0.1 * len(self.train_loader)) // self.grad_acc_steps, 1) * self.grad_acc_steps) # progress bar for batches - pb = progress_bar(range(len(self.dl.train)), parent=self.mb) + pb = progress_bar(range(len(self.train_loader)), parent=self.mb) self.ms_per_batch = 0. total_loss = 0. start_time = time.time() - for i, (x, y) in enumerate(self.dl.train): + for i, (x, y) in enumerate(self.train_loader): if self.progress_bar is True: pb.update(i) @@ -135,12 +148,15 @@ class Trainer(): with torch.cuda.amp.autocast(dtype=self.fp16_dtype, enabled=self.use_fp16): output = self.model(x) - loss = self.criterion(output.view(-1, len(self.dl.train.tok)), - y.reshape(-1)) + loss = self.criterion(output.view(-1, len(self.train_loader.dataset.tok)), + y.reshape(-1)) loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation - self.scaler.scale(loss).backward() + if self.distributon: + self.distributon.backward(loss) + else: + self.scaler.scale(loss).backward() total_loss += loss.item() @@ -150,13 +166,18 @@ class Trainer(): # when the gradient accumulation steps are more than 1, and the # number of batches doesn't cleanly divide by grad_acc_steps - # gradient clipping - self.scaler.unscale_(self.optimizer) + if not self.distributon: + # gradient clipping + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip_norm) - self.scaler.step(self.optimizer) - self.scaler.update() + # TODO: Fix this for distributed + if not self.distributon: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() self.scheduler.step() @@ -183,10 +204,10 @@ class Trainer(): total_loss = 0. 
 
        # progress bar for batches
-        pb = progress_bar(range(len(self.dl.test)), parent=self.mb)
+        pb = progress_bar(range(len(self.valid_loader)), parent=self.mb)
 
        with torch.no_grad():
-            for i, (x, y) in enumerate(self.dl.test):
+            for i, (x, y) in enumerate(self.valid_loader):
 
                if self.progress_bar is True:
                    pb.update(i)
@@ -194,14 +215,14 @@
                with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
                                             enabled=self.use_fp16):
                    output = self.model(x)
-                    loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
+                    loss = self.criterion(output.view(-1, len(self.valid_loader.dataset.tok)),
                                          y.reshape(-1))
 
                total_loss += loss.item()
 
                self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
 
-        self.val_loss = total_loss / (len(self.dl.test) - 1)
+        self.val_loss = total_loss / (len(self.valid_loader) - 1)
        self.val_ppl = math.exp(self.val_loss)
 
        pb.on_iter_end()
@@ -213,9 +234,9 @@
            f"* End of epoch {self.epoch:3d}:\n"
            f"\tTotal time: {epoch_time:9s} | "
            f"Est. ms/batch: {self.ms_per_batch:.2f}\n"
-            f"\tTotal train batches: {len(self.dl.train):10d} | "
+            f"\tTotal train batches: {len(self.train_loader):10d} | "
            f"Train loss: {self.train_loss: 7.2f} | "
            f"Train perplexity: {self.train_ppl: 8.2f}\n"
-            f"\tTotal valid batches: {len(self.dl.test):10d} | "
+            f"\tTotal valid batches: {len(self.valid_loader):10d} | "
            f"Valid loss: {self.val_loss: 7.2f} | "
            f"Valid perplexity: {self.val_ppl: 8.2f}")
diff --git a/training.py b/training.py
index 64cdf8f57d96c7bdc19e532c74aa67a6fa6035fe..f0c5998009444213b876e6d824c61a182ab42d26 100644
--- a/training.py
+++ b/training.py
@@ -2,12 +2,13 @@
 import fire
 import torch
 from torch import nn
 
-from optimus.datasets import WikiText103Dataset
+from optimus.datasets import TinyStoriesDataset
 from optimus.tokenizers import SentencePieceTokenizer
-from optimus.dataloader import OptimusDataLoader
+from optimus.distributon.dataloader import build_dataloader
 from optimus.models import OptimusTransformer
 from optimus.trainer import Trainer
+from optimus.distributon import Distributon
 
 def main(batch_size: int = 8,
          grad_acc_steps: int = 1,
@@ -21,7 +22,8 @@
          n_layers: int = 6,
          n_heads: int = 8,
          dropout: float = 0.0,
-         use_fp16: bool = True):
+         use_fp16: bool = True,
+         distributed: bool = False):
    """
    Run the main training loop for the model.
 
@@ -61,21 +63,38 @@
          f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
          f"Please see '--help' if you want to change these settings")
+
+    # Launch the distributed processes
+    if distributed:
+        distributon = Distributon([f"cuda:{i}" for i in range(torch.cuda.device_count())])
+        distributon.launch()
+        device = distributon._strategy.root_device
+    else:
+        distributon = None
+        device = 'cuda'
+
 
    # load tokenizer
    tok = SentencePieceTokenizer(model_path=tokenizer_path)
 
    # load dataset splits
-    train_ds = WikiText103Dataset(split='train')
-    test_ds = WikiText103Dataset(split='test')
+    train_ds = TinyStoriesDataset(split='train', tokenizer=tok)
+    test_ds = TinyStoriesDataset(split='test', tokenizer=tok)
 
    print(f"Number of examples in training set: {len(train_ds)}")
    print(f"Number of examples in testing set: {len(test_ds)}")
 
-    # create dataloader object and move to device
-    dl = OptimusDataLoader(train_ds, test_ds, tok,
+    # create the dataloaders
+    train_loader = build_dataloader(train_ds,
                           bs=batch_size,
                           seq_len=seq_len,
-                           device='cuda')
+                           device=device,
+                           distributed=distributed)
+
+    valid_loader = build_dataloader(test_ds,
+                                    bs=batch_size,
+                                    seq_len=seq_len,
+                                    device=device,
+                                    distributed=distributed)
 
    # create model and move to device
    model = OptimusTransformer(len(tok),
@@ -84,7 +102,9 @@
                               n_heads=n_heads,
                               p_drop=dropout,
                               weight_tying=False)
-    model = model.to('cuda')
+
+    if not distributed:
+        model = model.to('cuda')
 
    _total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {_total_params}")
@@ -101,7 +121,8 @@
    print("Starting training...")
 
    # create trainer and start fitting
-    trainer = Trainer(dl=dl,
+    trainer = Trainer(train_loader=train_loader,
+                      valid_loader=valid_loader,
                      model=model,
                      criterion=criterion,
                      optimizer=optimizer,
@@ -110,6 +131,7 @@
                      grad_clip_norm=grad_clip_norm,
                      model_save_path=checkpoints_path,
                      use_fp16=use_fp16,
+                      distributon=distributon,
                      progress_bar=True)
 
    trainer.fit(epochs)
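
Reviewer note (not part of the diff): `optimus/distributon/dataloader.py` and its `build_dataloader` are referenced by training.py but are not included in this change set. The deleted OptimusDataLoader produced (x, y) batches where y is x shifted right by one token, and Trainer still unpacks `for i, (x, y) in enumerate(self.train_loader)`, so the new loaders presumably have to keep that contract over the pre-tokenized story tensors now returned by TinyStoriesDataset. A minimal sketch of that batching, adapted from the deleted _OptimusDLIter; the helper name `iter_lm_batches` is hypothetical, for illustration only:

import torch
from torch import Tensor
from typing import Iterator, Tuple

def iter_lm_batches(stream: Tensor, bs: int, seq_len: int) -> Iterator[Tuple[Tensor, Tensor]]:
    """Yield (x, y) pairs of shape (bs, seq_len) from a 1-D token stream."""
    # split the stream into `bs` parallel sub-streams, dropping the remainder
    batch_len = stream.numel() // bs
    stream = stream[:batch_len * bs].view(bs, batch_len)
    # same guard as the deleted code: skip the last window when the shifted
    # targets would run past the end of the stream
    num_batches = (batch_len - 1) // seq_len
    for i in range(num_batches):
        x = stream[:, i * seq_len : (i + 1) * seq_len]
        y = stream[:, i * seq_len + 1 : (i + 1) * seq_len + 1]
        yield x, y

# e.g. with the new dataset: stream = torch.cat(list(train_ds), dim=-1)
# for x, y in iter_lm_batches(stream, bs=batch_size, seq_len=seq_len): ...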