From e6c82ba648c35db5467dce7c0343d263bddfd8c0 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 3 Jun 2024 22:44:10 +0300
Subject: [PATCH] Add DistributedDataParallel training

Add the most basic parallelism to the framework, through PyTorch's DDP.

Adjust the dataloaders to also use distributed samplers. Add supporting
utilities for distributed logging and distributed data processing.
---
 optimus/trainer.py             | 20 +++++--
 optimus/utils/dist_utils.py    | 95 ++++++++++++++++++++++++++++++++++
 optimus/utils/logging_utils.py | 13 +++++
 optimus/utils/setup_utils.py   | 12 ++++-
 training.py                    | 68 +++++++++++++++++++-----
 5 files changed, 190 insertions(+), 18 deletions(-)
 create mode 100644 optimus/utils/dist_utils.py
 create mode 100644 optimus/utils/logging_utils.py
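The changes below assume a torchrun-style launch (for example
"torchrun --nproc_per_node=<num_gpus> training.py"), since dist_utils reads
RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK from the
environment. A minimal sketch of the flow the patch wires in, using the
helpers added below; `build_model` and `train_set` are placeholders, not
part of the framework:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    from optimus.utils import dist_utils

    dist_utils.init_process_group()        # no-op outside a distributed env
    torch.cuda.set_device(dist_utils.get_local_rank())

    # build the model on this rank's GPU, then wrap it in DDP if distributed
    model = build_model().to(f'cuda:{dist_utils.get_local_rank()}')
    if dist_utils.is_dist_available_and_initialized():
        model = DDP(model, device_ids=[dist_utils.get_local_rank()])

    # each rank sees a disjoint shard of the data through the sampler
    sampler = DistributedSampler(train_set,
                                 num_replicas=dist_utils.get_world_size(),
                                 rank=dist_utils.get_rank(),
                                 shuffle=True)
    loader = DataLoader(train_set, batch_size=8, sampler=sampler)

    # ... training loop over `loader` ...

    dist_utils.destroy_process_group()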
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 2984f6b..00497ab 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -8,6 +8,8 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+from optimus.utils import dist_utils, logging_utils
+
 logger = logging.getLogger(__name__)
 
 
@@ -22,6 +24,8 @@ class TrainingArguments():
         log_steps (int): Log training progress each number of steps. This
             considers number of updates, so gradient_accumulation_steps
             influences this.
+        log_on_main_rank_only (bool, defaults to True): Whether the main rank
+            should be the only one that logs training information.
         show_progress (bool, defaults to True): Whether to show progress during
             training.
         seed (int): Seed used for reproducibility purposes.
@@ -57,10 +61,12 @@ class TrainingArguments():
         checkpoints_dir: Path,
         save_steps: int,
         save_limit: int,
+        log_on_main_rank_only: bool = True,
         show_progress: bool = True,
     ):
         self.device = device
         self.log_steps = log_steps
+        self.log_on_main_rank_only = log_on_main_rank_only
         self.show_progress = show_progress
         self.seed = seed
         self.optimizer = optimizer
@@ -106,6 +112,9 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer
 
+        if self.args.log_on_main_rank_only:
+            logger.addFilter(logging_utils.FilterMainRankOnly())
+
     def train(self) -> None:
         """
         Training loop of the trainer.
@@ -114,13 +123,13 @@ class Trainer():
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
         max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
-        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * 1
+        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size()
 
         fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
         loss_fn = torch.nn.CrossEntropyLoss()
 
-        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress))
+        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0)
 
         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
@@ -152,7 +161,12 @@ class Trainer():
                 logits = self.model(inputs)
 
                 labels = inputs[..., 1:].contiguous().view(-1)
-                logits = logits[..., :-1, :].contiguous().view(-1, self.model.vocab_size)
+                logits = logits[..., :-1, :].contiguous().view(
+                    -1,
+                    self.model.module.vocab_size
+                    if dist_utils.is_dist_available_and_initialized()
+                    else self.model.vocab_size
+                )
 
                 loss = loss_fn(logits, labels)
 
diff --git a/optimus/utils/dist_utils.py b/optimus/utils/dist_utils.py
new file mode 100644
index 0000000..de159f6
--- /dev/null
+++ b/optimus/utils/dist_utils.py
@@ -0,0 +1,95 @@
+import os
+import logging
+import contextlib
+
+import torch.distributed as dist
+
+
+logger = logging.getLogger(__name__)
+
+
+def init_process_group():
+    if 'RANK' not in os.environ or \
+            'WORLD_SIZE' not in os.environ or \
+            'MASTER_ADDR' not in os.environ or \
+            'MASTER_PORT' not in os.environ:
+        # at this point, we're pretty sure we are not in a distributed
+        # environment
+        pass
+    else:
+        dist.init_process_group('nccl')
+
+
+def destroy_process_group():
+    if 'RANK' not in os.environ or \
+            'WORLD_SIZE' not in os.environ or \
+            'MASTER_ADDR' not in os.environ or \
+            'MASTER_PORT' not in os.environ:
+        # at this point, we're pretty sure we are not in a distributed
+        # environment
+        pass
+    else:
+        dist.destroy_process_group()
+
+
+def is_dist_available_and_initialized():
+    if not dist.is_available():
+        return False
+
+    if not dist.is_initialized():
+        return False
+
+    return True
+
+
+def get_rank():
+    if not is_dist_available_and_initialized():
+        return 0
+
+    return dist.get_rank()
+
+
+def get_local_rank():
+    try:
+        # this should be automatically detected when running with torchrun
+        local_rank = int(os.environ['LOCAL_RANK'])
+    except KeyError:
+        # if you run by hand, make sure to correctly set ONLY ONE OF:
+        # - LOCAL_RANK=<x>
+        # OR
+        # - CUDA_VISIBLE_DEVICES=<x>
+        # , where x is one of the GPUs available on the system
+        local_rank = 0
+    return local_rank
+
+
+def get_world_size():
+    if not is_dist_available_and_initialized():
+        return 1
+
+    return dist.get_world_size()
+
+
+def is_main_rank():
+    return get_rank() == 0
+
+
+@contextlib.contextmanager
+def main_process_first():
+    rank = get_rank()
+    if rank == 0:
+        try:
+            logger.info("Rank 0 starting work")
+            yield
+        finally:
+            logger.info("Rank 0 done with work")
+            if is_dist_available_and_initialized():
+                dist.barrier(device_ids=[get_local_rank()])
+    else:
+        try:
+            logger.info(f"Rank {rank} waiting for rank 0 to process work")
+            if is_dist_available_and_initialized():
+                dist.barrier(device_ids=[get_local_rank()])
+        finally:
+            logger.info(f"Rank {rank} done waiting for rank 0")
+            yield
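main_process_first() is used further down in training.py so that rank 0
preprocesses the dataset while the other ranks wait at the barrier and then
run the same body afterwards, reusing any on-disk cache instead of redoing
the work (this assumes the dataset loader caches to disk, as HF datasets
does). A minimal usage sketch; dataset_dir and seq_len stand in for the real
arguments from training.py:

    from optimus.utils import dist_utils, setup_utils

    # rank 0 enters the body first and builds/caches the chunked dataset;
    # the other ranks block on the barrier and only then run the body,
    # picking up the cached result
    with dist_utils.main_process_first():
        dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len)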
diff --git a/optimus/utils/logging_utils.py b/optimus/utils/logging_utils.py
new file mode 100644
index 0000000..f762da8
--- /dev/null
+++ b/optimus/utils/logging_utils.py
@@ -0,0 +1,13 @@
+import os
+import logging
+
+import torch.distributed as dist
+
+
+class FilterMainRankOnly(logging.Filter):
+    def filter(self, record):
+        if not dist.is_available():
+            return True
+        if not dist.is_initialized():
+            return True
+        return dist.get_rank() == 0
diff --git a/optimus/utils/setup_utils.py b/optimus/utils/setup_utils.py
index d50ad19..e4eb65d 100644
--- a/optimus/utils/setup_utils.py
+++ b/optimus/utils/setup_utils.py
@@ -1,9 +1,11 @@
 import logging
 
+from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers import AutoTokenizer
 from datasets import load_from_disk
 
-from optimus.models.optimus import OptimusTransformer, OptimusConfig
+from optimus.models.optimus import OptimusConfig, OptimusTransformer
+from optimus.utils import dist_utils
 
 logger = logging.getLogger(__name__)
 
@@ -71,6 +73,14 @@ def create_model(config_file, device):
     model = OptimusTransformer(config)
     model.to(device)
 
+    if dist_utils.is_dist_available_and_initialized():
+        model = DDP(
+            model,
+            device_ids=[dist_utils.get_local_rank()],
+            output_device=dist_utils.get_local_rank(),
+            gradient_as_bucket_view=True,
+        )
+
     logger.info(model)
 
     _total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
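Since create_model() now returns a DDP-wrapped module in distributed runs,
custom attributes such as vocab_size move behind model.module, which is why
trainer.py above grows the conditional attribute access. A short sketch of
the same pattern, assuming `model` is whatever create_model() returned:

    from optimus.utils import dist_utils

    # under DDP the underlying OptimusTransformer sits behind `.module`
    raw_model = (
        model.module
        if dist_utils.is_dist_available_and_initialized()
        else model
    )
    vocab_size = raw_model.vocab_size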
""" + # set up logging before anything else + logging.basicConfig( + level=os.environ.get('LOG_LEVEL', logging.INFO), + format='{asctime} - {filename} | {levelname} | {message}', + style='{', + ) + # create paths if isinstance(data_dir, str): data_dir = Path(data_dir) @@ -94,14 +106,6 @@ def main( else: dataset_dir = data_dir / dataset_dir - # set up logging before anything else - logging.basicConfig( - level=os.environ.get('LOG_LEVEL', logging.INFO), - format='{asctime} - {filename} | {levelname} | {message}', - style='{', - ) - - # create paths if isinstance(model_config_path, str): model_config_path = Path(model_config_path) @@ -139,6 +143,7 @@ def main( f"\t- gradient clipping norm: {grad_clip_norm}\n" f"\t- 16-bit floating-point training (fp16): {use_fp16}\n" f"\t- seed: {seed}\n" + f"\t- only main rank logs: {log_on_main_rank_only}\n" f"\t- model config file: {model_config_path}\n" f"\t- huggingface tokenizer: {hf_tokenizer_name}\n" f"\t- training data directory: {str(data_dir)}\n" @@ -148,12 +153,26 @@ def main( f"\t- saved model directory: {str(save_dir)}\n" f"Please seek '--help' if you want to change any of these settings") + if log_on_main_rank_only: + logger.addFilter(logging_utils.FilterMainRankOnly()) + + # init distributed process group + dist_utils.init_process_group() + torch.manual_seed(seed + dist_utils.get_rank()) + # set device - device = f'cuda' + device = f'cuda:{dist_utils.get_local_rank()}' + + logger.info(f'Rank {dist_utils.get_rank()}, world size {dist_utils.get_world_size()}, device {device}') + + # set each process's GPU + torch.cuda.set_device(dist_utils.get_local_rank()) + torch.cuda.empty_cache() # load dataset and split into batches - dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len) - dataset.set_format('torch') + with dist_utils.main_process_first(): + dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len) + dataset.set_format('torch') # load tokenizer tokenizer = setup_utils.create_tokenizer(hf_tokenizer_name) @@ -161,13 +180,30 @@ def main( # create model and move to device model = setup_utils.create_model(model_config_path, device) - # create dataloaders + # create samplers + dataloaders + train_dist_sampler = DistributedSampler( + dataset['train'], + num_replicas=dist_utils.get_world_size(), + rank=dist_utils.get_rank(), + shuffle=True, + seed=seed, + drop_last=True, + ) + train_dataloader = DataLoader( dataset['train'], batch_size=batch_size, # per GPU num_workers=6, # allow pre-fetching data through multi-process pin_memory=True, # fast CPU-GPU transfer + sampler=train_dist_sampler, + ) + + eval_dist_sampler = DistributedSampler( + dataset['test'], + num_replicas=dist_utils.get_world_size(), + rank=dist_utils.get_rank(), shuffle=True, + seed=seed, drop_last=True, ) @@ -176,7 +212,7 @@ def main( batch_size=batch_size, # per GPU num_workers=6, # allow pre-fetching data through multi-process pin_memory=True, # fast CPU-GPU transfer - drop_last=True, + sampler=eval_dist_sampler, ) # create optimizer @@ -201,6 +237,7 @@ def main( # logging log_steps=100, + log_on_main_rank_only=log_on_main_rank_only, # core training seed=seed, @@ -239,6 +276,9 @@ def main( # save log data trainer.save_logs(log_dir) + # clean distributed process group + dist_utils.destroy_process_group() + if __name__=='__main__': fire.Fire(main) -- GitLab