From e6c82ba648c35db5467dce7c0343d263bddfd8c0 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 22:44:10 +0300
Subject: [PATCH] Add DistributedDataParallel training

Add basic data parallelism to the framework through PyTorch's
DistributedDataParallel (DDP). Adjust the dataloaders to use distributed
samplers.

Also add helpers for distributed logging and distributed process-group
handling.
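
For reference, the general pattern this patch wires in is sketched below
(a minimal, standalone example with a placeholder model and dataset, not
the framework's actual code):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT
    dist.init_process_group('nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    # wrap the model so gradients are all-reduced across ranks
    model = torch.nn.Linear(16, 16).to(f'cuda:{local_rank}')
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # the sampler shards the dataset so each rank sees a distinct slice
    dataset = TensorDataset(torch.randn(128, 16))
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for epoch in range(2):
        # set_epoch() is typically needed so shuffling differs per epoch
        sampler.set_epoch(epoch)
        for (batch,) in loader:
            model(batch.to(f'cuda:{local_rank}')).mean().backward()

    dist.destroy_process_group()

With this patch, training.py can still be run single-process as before, or
distributed with something like: torchrun --nproc_per_node=<gpus> training.py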
---
 optimus/trainer.py             | 20 +++++--
 optimus/utils/dist_utils.py    | 95 ++++++++++++++++++++++++++++++++++
 optimus/utils/logging_utils.py | 13 +++++
 optimus/utils/setup_utils.py   | 12 ++++-
 training.py                    | 68 +++++++++++++++++++-----
 5 files changed, 190 insertions(+), 18 deletions(-)
 create mode 100644 optimus/utils/dist_utils.py
 create mode 100644 optimus/utils/logging_utils.py

diff --git a/optimus/trainer.py b/optimus/trainer.py
index 2984f6b..00497ab 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -8,6 +8,8 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+from optimus.utils import dist_utils, logging_utils
+
 
 logger = logging.getLogger(__name__)
 
@@ -22,6 +24,8 @@ class TrainingArguments():
         log_steps (int): Log training progress each number of steps. This
             considers number of updates, so gradient_accumulation_steps
             influences this.
+        log_on_main_rank_only (bool, defaults to True): Whether the main rank
+            should be the only one that logs training information.
         show_progress (bool, defaults to True): Whether to show progress during
             training.
         seed (int): Seed used for reproducibility purposes.
@@ -57,10 +61,12 @@ class TrainingArguments():
         checkpoints_dir: Path,
         save_steps: int,
         save_limit: int,
+        log_on_main_rank_only: bool = True,
         show_progress: bool = True,
     ):
         self.device = device
         self.log_steps = log_steps
+        self.log_on_main_rank_only = log_on_main_rank_only
         self.show_progress = show_progress
         self.seed = seed
         self.optimizer = optimizer
@@ -106,6 +112,9 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        if self.args.log_on_main_rank_only:
+            logger.addFilter(logging_utils.FilterMainRankOnly())
+
     def train(self) -> None:
         """
         Training loop of the trainer.
@@ -114,13 +123,13 @@ class Trainer():
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
         max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
-        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * 1
+        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size()
 
         fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
         loss_fn = torch.nn.CrossEntropyLoss()
 
-        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress))
+        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0)
 
         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
@@ -152,7 +161,12 @@ class Trainer():
                     logits = self.model(inputs)
 
                     labels = inputs[..., 1:].contiguous().view(-1)
-                    logits = logits[..., :-1, :].contiguous().view(-1, self.model.vocab_size)
+                    logits = logits[..., :-1, :].contiguous().view(
+                        -1,
+                        self.model.module.vocab_size
+                        if dist_utils.is_dist_available_and_initialized()
+                        else self.model.vocab_size
+                    )
 
                     loss = loss_fn(logits, labels)
 
diff --git a/optimus/utils/dist_utils.py b/optimus/utils/dist_utils.py
new file mode 100644
index 0000000..de159f6
--- /dev/null
+++ b/optimus/utils/dist_utils.py
@@ -0,0 +1,95 @@
+import os
+import logging
+import contextlib
+
+import torch.distributed as dist
+
+
+logger = logging.getLogger(__name__)
+
+
+def init_process_group():
+    if 'RANK' not in os.environ or \
+        'WORLD_SIZE' not in os.environ or \
+        'MASTER_ADDR' not in os.environ or \
+        'MASTER_PORT' not in os.environ:
+        # at this point, we're pretty sure we are not in a distributed
+        # environment
+        pass
+    else:
+        dist.init_process_group('nccl')
+
+
+def destroy_process_group():
+    if 'RANK' not in os.environ or \
+        'WORLD_SIZE' not in os.environ or \
+        'MASTER_ADDR' not in os.environ or \
+        'MASTER_PORT' not in os.environ:
+        # at this point, we're pretty sure we are not in a distributed
+        # environment
+        pass
+    else:
+        dist.destroy_process_group()
+
+
+def is_dist_available_and_initialized():
+    if not dist.is_available():
+        return False
+
+    if not dist.is_initialized():
+        return False
+
+    return True
+
+
+def get_rank():
+    if not is_dist_available_and_initialized():
+        return 0
+
+    return dist.get_rank()
+
+
+def get_local_rank():
+    try:
+        # this should be automatically detected when running with torchrun
+        local_rank = int(os.environ['LOCAL_RANK'])
+    except (KeyError, ValueError):
+        # if you run by hand, make sure to correctly set ONLY ONE OF:
+        # - LOCAL_RANK=<x>
+        #   OR
+        # - CUDA_VISIBLE_DEVICES=<x>
+        # where x is the index of one of the GPUs available on the system
+        local_rank = 0
+    return local_rank
+
+
+def get_world_size():
+    if not is_dist_available_and_initialized():
+        return 1
+
+    return dist.get_world_size()
+
+
+def is_main_rank():
+    return get_rank() == 0
+
+
+@contextlib.contextmanager
+def main_process_first():
+    rank = get_rank()
+    if rank == 0:
+        try:
+            logger.info(f"Rank 0 start work")
+            yield
+        finally:
+            logger.info(f"Rank 0 done with work")
+            if is_dist_available_and_initialized():
+                dist.barrier(device_ids=[get_local_rank()])
+    else:
+        try:
+            logger.info(f"Rank {rank} waiting for rank 0 to process work")
+            if is_dist_available_and_initialized():
+                dist.barrier(device_ids=[get_local_rank()])
+        finally:
+            logger.info(f"Rank {rank} done waiting for rank 0")
+            yield
diff --git a/optimus/utils/logging_utils.py b/optimus/utils/logging_utils.py
new file mode 100644
index 0000000..f762da8
--- /dev/null
+++ b/optimus/utils/logging_utils.py
@@ -0,0 +1,13 @@
+import os
+import logging
+
+import torch.distributed as dist
+
+
+class FilterMainRankOnly(logging.Filter):
+    def filter(self, record):
+        if not dist.is_available():
+            return True
+        if not dist.is_initialized():
+            return True
+        return dist.get_rank() == 0
diff --git a/optimus/utils/setup_utils.py b/optimus/utils/setup_utils.py
index d50ad19..e4eb65d 100644
--- a/optimus/utils/setup_utils.py
+++ b/optimus/utils/setup_utils.py
@@ -1,9 +1,11 @@
 import logging
 
+from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers import AutoTokenizer
 from datasets import load_from_disk
 
-from optimus.models.optimus import OptimusTransformer, OptimusConfig
+from optimus.models.optimus import OptimusConfig, OptimusTransformer
+from optimus.utils import dist_utils
 
 
 logger = logging.getLogger(__name__)
@@ -71,6 +73,14 @@ def create_model(config_file, device):
     model = OptimusTransformer(config)
     model.to(device)
 
+    if dist_utils.is_dist_available_and_initialized():
+        model = DDP(
+            model,
+            device_ids=[dist_utils.get_local_rank()],
+            output_device=dist_utils.get_local_rank(),
+            gradient_as_bucket_view=True,
+        )
+
     logger.info(model)
 
     _total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
diff --git a/training.py b/training.py
index b822fb1..9a6feea 100644
--- a/training.py
+++ b/training.py
@@ -4,10 +4,12 @@ from pathlib import Path
 
 import fire
 import torch
+import torch.distributed as dist
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
 from optimus.trainer import Trainer, TrainingArguments
-from optimus.utils import setup_utils
+from optimus.utils import dist_utils, logging_utils, setup_utils
 
 
 logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ def main(
     grad_clip_norm: float = 1.0,
     use_fp16: bool = True,
     seed: int = 42,
+    log_on_main_rank_only: bool = True,
 
     # model + tokenizer
     model_config_path: Path = Path('config.json'),
@@ -63,6 +66,8 @@ def main(
             precision. Bfloat16 is used if available, otherwise fp16.
         seed (int): Seed used for reproducibility purposes. Each process will
             have its initial PyTorch seed set to `seed + process_rank`.
+        log_on_main_rank_only (bool): Whether output should only be printed by
+            the main rank (rank 0).
         model_config_path (Path): Path to a config file which describes the
             model to be trained.
         hf_tokenizer_name (str): HuggingFace tokenizer name.
@@ -85,6 +90,13 @@ def main(
             `data_dir`.
 
     """
+    # set up logging before anything else
+    logging.basicConfig(
+        level=os.environ.get('LOG_LEVEL', logging.INFO),
+        format='{asctime} - {filename} | {levelname} | {message}',
+        style='{',
+    )
+
     # create paths
     if isinstance(data_dir, str):
         data_dir = Path(data_dir)
@@ -94,14 +106,6 @@ def main(
     else:
         dataset_dir = data_dir / dataset_dir
 
-    # set up logging before anything else
-    logging.basicConfig(
-        level=os.environ.get('LOG_LEVEL', logging.INFO),
-        format='{asctime} - {filename} | {levelname} | {message}',
-        style='{',
-    )
-
-    # create paths
     if isinstance(model_config_path, str):
         model_config_path = Path(model_config_path)
 
@@ -139,6 +143,7 @@ def main(
         f"\t- gradient clipping norm: {grad_clip_norm}\n"
         f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
         f"\t- seed: {seed}\n"
+        f"\t- only main rank logs: {log_on_main_rank_only}\n"
         f"\t- model config file: {model_config_path}\n"
         f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
         f"\t- training data directory: {str(data_dir)}\n"
@@ -148,12 +153,26 @@ def main(
         f"\t- saved model directory: {str(save_dir)}\n"
         f"Please seek '--help' if you want to change any of these settings")
 
+    if log_on_main_rank_only:
+        logger.addFilter(logging_utils.FilterMainRankOnly())
+
+    # init distributed process group
+    dist_utils.init_process_group()
+    torch.manual_seed(seed + dist_utils.get_rank())
+
     # set device
-    device = f'cuda'
+    device = f'cuda:{dist_utils.get_local_rank()}'
+
+    logger.info(f'Rank {dist_utils.get_rank()}, world size {dist_utils.get_world_size()}, device {device}')
+
+    # set each process's GPU
+    torch.cuda.set_device(dist_utils.get_local_rank())
+    torch.cuda.empty_cache()
 
     # load dataset and split into batches
-    dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len)
-    dataset.set_format('torch')
+    with dist_utils.main_process_first():
+        dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len)
+        dataset.set_format('torch')
 
     # load tokenizer
     tokenizer = setup_utils.create_tokenizer(hf_tokenizer_name)
@@ -161,13 +180,30 @@ def main(
     # create model and move to device
     model = setup_utils.create_model(model_config_path, device)
 
-    # create dataloaders
+    # create samplers + dataloaders
+    train_dist_sampler = DistributedSampler(
+        dataset['train'],
+        num_replicas=dist_utils.get_world_size(),
+        rank=dist_utils.get_rank(),
+        shuffle=True,
+        seed=seed,
+        drop_last=True,
+    )
+
     train_dataloader = DataLoader(
         dataset['train'],
         batch_size=batch_size, # per GPU
         num_workers=6, # allow pre-fetching data through multi-process
         pin_memory=True, # fast CPU-GPU transfer
+        sampler=train_dist_sampler,
+    )
+
+    eval_dist_sampler = DistributedSampler(
+        dataset['test'],
+        num_replicas=dist_utils.get_world_size(),
+        rank=dist_utils.get_rank(),
         shuffle=True,
+        seed=seed,
         drop_last=True,
     )
 
@@ -176,7 +212,7 @@ def main(
         batch_size=batch_size, # per GPU
         num_workers=6, # allow pre-fetching data through multi-process
         pin_memory=True, # fast CPU-GPU transfer
-        drop_last=True,
+        sampler=eval_dist_sampler,
     )
 
     # create optimizer
@@ -201,6 +237,7 @@ def main(
 
         # logging
         log_steps=100,
+        log_on_main_rank_only=log_on_main_rank_only,
 
         # core training
         seed=seed,
@@ -239,6 +276,9 @@ def main(
     # save log data
     trainer.save_logs(log_dir)
 
+    # clean distributed process group
+    dist_utils.destroy_process_group()
+
 
 if __name__=='__main__':
     fire.Fire(main)
-- 
GitLab