Unverified Commit e39773c1 authored by Alexandru-Mihai GHERGHESCU

Add DistributedDataParallel training

Add the most basic parallelism to the framework, through PyTorch's DDP.
Adjust the dataloaders to also use distributed samplers.

Add other goodies for distributed logging + distributed processing.
parent d96dcf92
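For orientation before the diff itself: the commit follows the standard PyTorch DDP recipe (one process per GPU launched via torchrun, a DistributedSampler per dataloader, and the model wrapped in DistributedDataParallel). The sketch below is illustrative only, with a placeholder model and dataset, and is not code from this repository:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for each process
    dist.init_process_group('nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(16, 2).cuda(local_rank)   # placeholder model
    model = DDP(model, device_ids=[local_rank])       # wrap for gradient all-reduce

    data = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
    sampler = DistributedSampler(data, shuffle=True, drop_last=True)
    loader = DataLoader(data, batch_size=8, sampler=sampler)  # each rank gets a shard

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in range(2):
        sampler.set_epoch(epoch)                      # reshuffle differently each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()           # DDP all-reduces gradients here
            optimizer.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    main()

Such a script would be launched as, for example: torchrun --nproc_per_node=<num_gpus> <script>.py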
optimus/trainer.py:

@@ -8,6 +8,8 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from optimus.utils import dist_utils, logging_utils

 logger = logging.getLogger(__name__)

@@ -22,6 +24,8 @@ class TrainingArguments():
         log_steps (int): Log training progress each number of steps. This
             considers number of updates, so gradient_accumulation_steps
             influences this.
+        log_on_main_rank_only (bool, defaults to True): Whether the main rank
+            should be the only one that logs training information.
         show_progress (bool, defaults to True): Whether to show progress during
             training.
         seed (int): Seed used for reproducibility purposes.

@@ -57,10 +61,12 @@ class TrainingArguments():
         checkpoints_dir: Path,
         save_steps: int,
         save_limit: int,
+        log_on_main_rank_only: bool = True,
         show_progress: bool = True,
     ):
         self.device = device
         self.log_steps = log_steps
+        self.log_on_main_rank_only = log_on_main_rank_only
         self.show_progress = show_progress
         self.seed = seed
         self.optimizer = optimizer

@@ -106,6 +112,9 @@ class Trainer():
         self.eval_dataloader = eval_dataloader
         self.tokenizer = tokenizer

+        if self.args.log_on_main_rank_only:
+            logger.addFilter(logging_utils.FilterMainRankOnly())

     def train(self) -> None:
         """
         Training loop of the trainer.

@@ -114,13 +123,13 @@ class Trainer():
         num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
         num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
         max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
-        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * 1
+        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * dist_utils.get_world_size()

         fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         loss_fn = torch.nn.CrossEntropyLoss()

-        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress))
+        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress) or dist_utils.get_rank() != 0)

         # scaler used for mixed precision fp16 training on GPU
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
@@ -152,7 +161,7 @@ class Trainer():
                 logits = self.model(inputs)

                 labels = inputs[..., 1:].contiguous().view(-1)
-                logits = logits[..., :-1, :].contiguous().view(-1, self.model.vocab_size)
+                logits = logits[..., :-1, :].contiguous().view(-1, self.model.module.vocab_size)

                 loss = loss_fn(logits, labels)
...
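A note on the `self.model.vocab_size` -> `self.model.module.vocab_size` change above: once `create_model` (further down in this commit) wraps the network in DistributedDataParallel, custom attributes such as `vocab_size` live on the wrapped network, which DDP exposes as `.module`. A hedged alternative, not what the commit does, would be a small helper that works whether or not the model is wrapped:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # DDP keeps the original network as its `.module` attribute.
    return model.module if isinstance(model, DDP) else model

# hypothetical use inside Trainer.train():
#   vocab_size = unwrap_model(self.model).vocab_size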
New file: optimus/utils/dist_utils.py:

import os
import logging
import contextlib

import torch.distributed as dist

logger = logging.getLogger(__name__)


def init_process_group():
    if 'RANK' not in os.environ or \
            'WORLD_SIZE' not in os.environ or \
            'MASTER_ADDR' not in os.environ or \
            'MASTER_PORT' not in os.environ:
        # at this point, we're pretty sure we are not in a distributed
        # environment
        pass
    else:
        dist.init_process_group('nccl')


def destroy_process_group():
    if 'RANK' not in os.environ or \
            'WORLD_SIZE' not in os.environ or \
            'MASTER_ADDR' not in os.environ or \
            'MASTER_PORT' not in os.environ:
        # at this point, we're pretty sure we are not in a distributed
        # environment
        pass
    else:
        dist.destroy_process_group()


def is_dist_available_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_available_and_initialized():
        return 0
    return dist.get_rank()


def get_local_rank():
    try:
        # this should be automatically detected when running with torchrun
        local_rank = int(os.environ['LOCAL_RANK'])
    except (KeyError, ValueError):
        # if you run by hand, make sure to correctly set ONLY ONE OF:
        # - LOCAL_RANK=<x>
        # OR
        # - CUDA_VISIBLE_DEVICES=<x>
        # , where x is one of the GPUs available on the system
        local_rank = 0
    return local_rank


def get_world_size():
    if not is_dist_available_and_initialized():
        return 1
    return dist.get_world_size()


def is_main_rank():
    return get_rank() == 0


@contextlib.contextmanager
def main_process_first():
    rank = get_rank()
    if rank == 0:
        try:
            logger.info("Rank 0 starting work")
            yield
        finally:
            logger.info("Rank 0 done with work")
            if is_dist_available_and_initialized():
                dist.barrier(device_ids=[get_local_rank()])
    else:
        try:
            logger.info(f"Rank {rank} waiting for rank 0 to process work")
            if is_dist_available_and_initialized():
                dist.barrier(device_ids=[get_local_rank()])
        finally:
            logger.info(f"Rank {rank} done waiting for rank 0")
            yield
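`main_process_first()` above is what the training script later uses around dataset preprocessing: rank 0 runs the block first, the other ranks wait on a barrier and then run the same block themselves (typically hitting whatever rank 0 cached). A minimal usage sketch; `prepare_dataset` is a placeholder, not a function from this repository:

from optimus.utils import dist_utils


def prepare_dataset(path):
    # placeholder for a real, cache-backed preprocessing step
    return [path]


with dist_utils.main_process_first():
    dataset = prepare_dataset('/tmp/data')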
New file: optimus/utils/logging_utils.py:

import os
import logging

import torch.distributed as dist


class FilterMainRankOnly(logging.Filter):
    def filter(self, record):
        if not dist.is_available():
            return True
        if not dist.is_initialized():
            return True
        return dist.get_rank() == 0
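The filter attaches to a logger like any `logging.Filter`; this is the pattern both the Trainer and the training script use in this commit. A short sketch, assuming the `optimus` package is importable:

import logging

from optimus.utils import logging_utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.addFilter(logging_utils.FilterMainRankOnly())

# Emitted on rank 0 (or outside a distributed run); dropped on every other rank.
logger.info('only the main rank prints this')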
optimus/utils/setup_utils.py:

 import logging

+from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers import AutoTokenizer
 from datasets import load_from_disk

-from optimus.models.optimus import OptimusTransformer, OptimusConfig
+from optimus.models.optimus import OptimusConfig, OptimusTransformer
+from optimus.utils import dist_utils

 logger = logging.getLogger(__name__)

@@ -71,6 +73,13 @@ def create_model(config_file, device):
     model = OptimusTransformer(config)
     model.to(device)

+    model = DDP(
+        model,
+        device_ids=[dist_utils.get_local_rank()],
+        output_device=dist_utils.get_local_rank(),
+        gradient_as_bucket_view=True,
+    )

     logger.info(model)

     _total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
...
Training script:

@@ -4,10 +4,12 @@ from pathlib import Path
 import fire
 import torch
+import torch.distributed as dist
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler

 from optimus.trainer import Trainer, TrainingArguments
-from optimus.utils import setup_utils
+from optimus.utils import dist_utils, logging_utils, setup_utils

 logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ def main(
     grad_clip_norm: float = 1.0,
     use_fp16: bool = True,
     seed: int = 42,
+    log_on_main_rank_only: bool = True,

     # model + tokenizer
     model_config_path: Path = Path('config.json'),

@@ -63,6 +66,8 @@ def main(
             precision. Bfloat16 is used if available, otherwise fp16.
         seed (int): Seed used for reproducibility purposes. Each process will
             have its initial PyTorch seed set to `seed + process_rank`.
+        log_on_main_rank_only (bool): Whether output should only be printed by
+            the main rank (rank 0).
         model_config_path (Path): Path to a config file which describes the
             model to be trained.
         hf_tokenizer_name (str): HuggingFace tokenizer name.
@@ -85,6 +90,13 @@ def main(
             `data_dir`.
     """
+    # set up logging before anything else
+    logging.basicConfig(
+        level=os.environ.get('LOG_LEVEL', logging.INFO),
+        format='{asctime} - {filename} | {levelname} | {message}',
+        style='{',
+    )

     # create paths
     if isinstance(data_dir, str):
         data_dir = Path(data_dir)

@@ -94,14 +106,6 @@ def main(
     else:
         dataset_dir = data_dir / dataset_dir

-    # set up logging before anything else
-    logging.basicConfig(
-        level=os.environ.get('LOG_LEVEL', logging.INFO),
-        format='{asctime} - {filename} | {levelname} | {message}',
-        style='{',
-    )
-
-    # create paths
     if isinstance(model_config_path, str):
         model_config_path = Path(model_config_path)
@@ -139,6 +143,7 @@ def main(
         f"\t- gradient clipping norm: {grad_clip_norm}\n"
         f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
         f"\t- seed: {seed}\n"
+        f"\t- only main rank logs: {log_on_main_rank_only}\n"
         f"\t- model config file: {model_config_path}\n"
         f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
         f"\t- training data directory: {str(data_dir)}\n"

@@ -148,12 +153,26 @@ def main(
         f"\t- saved model directory: {str(save_dir)}\n"
         f"Please seek '--help' if you want to change any of these settings")

+    if log_on_main_rank_only:
+        logger.addFilter(logging_utils.FilterMainRankOnly())
+
+    # init distributed process group
+    dist_utils.init_process_group()
+
+    torch.manual_seed(seed + dist_utils.get_rank())

     # set device
-    device = f'cuda'
+    device = f'cuda:{dist_utils.get_local_rank()}'
+
+    logger.info(f'Rank {dist_utils.get_rank()}, world size {dist_utils.get_world_size()}, device {device}')
+
+    # set each process's GPU
+    torch.cuda.set_device(dist_utils.get_local_rank())
+    torch.cuda.empty_cache()

     # load dataset and split into batches
-    dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len)
-    dataset.set_format('torch')
+    with dist_utils.main_process_first():
+        dataset = setup_utils.load_and_chunk_dataset(str(dataset_dir), seq_len)
+        dataset.set_format('torch')

     # load tokenizer
     tokenizer = setup_utils.create_tokenizer(hf_tokenizer_name)
@@ -161,13 +180,30 @@ def main(
     # create model and move to device
     model = setup_utils.create_model(model_config_path, device)

-    # create dataloaders
+    # create samplers + dataloaders
+    train_dist_sampler = DistributedSampler(
+        dataset['train'],
+        num_replicas=dist_utils.get_world_size(),
+        rank=dist_utils.get_rank(),
+        shuffle=True,
+        seed=seed,
+        drop_last=True,
+    )
+
     train_dataloader = DataLoader(
         dataset['train'],
         batch_size=batch_size, # per GPU
         num_workers=6, # allow pre-fetching data through multi-process
         pin_memory=True, # fast CPU-GPU transfer
-        shuffle=True,
-        drop_last=True,
+        sampler=train_dist_sampler,
     )

+    eval_dist_sampler = DistributedSampler(
+        dataset['test'],
+        num_replicas=dist_utils.get_world_size(),
+        rank=dist_utils.get_rank(),
+        shuffle=True,
+        seed=seed,
+        drop_last=True,
+    )

@@ -176,7 +212,7 @@ def main(
         batch_size=batch_size, # per GPU
         num_workers=6, # allow pre-fetching data through multi-process
         pin_memory=True, # fast CPU-GPU transfer
-        drop_last=True,
+        sampler=eval_dist_sampler,
     )

     # create optimizer
@@ -201,6 +237,7 @@ def main(
         # logging
         log_steps=100,
+        log_on_main_rank_only=log_on_main_rank_only,

         # core training
         seed=seed,

@@ -239,6 +276,9 @@ def main(
     # save log data
     trainer.save_logs(log_dir)

+    # clean distributed process group
+    dist_utils.destroy_process_group()

 if __name__=='__main__':
     fire.Fire(main)
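Two practical notes on running this script. It expects to be launched under torchrun (or with RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT exported by hand); otherwise `dist_utils.init_process_group()` silently falls back to a single-process run. Also, `DistributedSampler` with `shuffle=True` only reshuffles across epochs if `set_epoch()` is called each epoch, which is worth checking in the Trainer when multi-epoch runs are planned. A tiny, purely illustrative check of which mode a process will start in:

import os

# The same four variables dist_utils.init_process_group() looks for; torchrun
# sets them automatically, a plain `python <script>.py` run leaves them unset.
required = ('RANK', 'WORLD_SIZE', 'MASTER_ADDR', 'MASTER_PORT')
missing = [name for name in required if name not in os.environ]
if missing:
    print(f"single-process fallback (missing: {', '.join(missing)})")
else:
    print(f"distributed run: rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']}")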