Commit 6d3a3004 authored by Vlad-Andrei BĂDOIU (78692)

Introduce the module for distributed training

parent 8579fc15
1 merge request: !21 Draft: Add support for data parallelism on a single node
from .distributon import Distributon
\ No newline at end of file
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, DistributedSampler
def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
# cannot use a set because samplers might be unhashable: use a dict based on the id to drop duplicates
objects: Dict[int, Any] = {}
# check dataloader.sampler
if (sampler := getattr(dataloader, "sampler", None)) is not None:
objects[id(sampler)] = sampler
# check dataloader.batch_sampler.sampler
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
sampler := getattr(batch_sampler, "sampler", None)
) is not None:
objects[id(sampler)] = sampler
for obj in objects.values():
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)
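# For illustration (not part of the original commit): _set_sampler_epoch automates the manual
# pattern that DistributedSampler otherwise requires, which would look like
#
#     sampler = DistributedSampler(dataset, shuffle=True)
#     loader = DataLoader(dataset, sampler=sampler, batch_size=8)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)  # re-seed the shuffle for this epoch
#         for batch in loader:
#             ...
#
# The _DistributonDataLoader wrapper below passes its count of completed __iter__ calls as the epoch.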
def apply_to_collection(
data: Any,
function: Callable,
*args: Any,
**kwargs: Any,
) -> Any:
    if data.__class__ is list:
        return [function(x, *args, **kwargs) for x in data]
    if data.__class__ is tuple:
        return tuple(function(x, *args, **kwargs) for x in data)
    if data.__class__ is dict:
        return {k: function(v, *args, **kwargs) for k, v in data.items()}
    # Leaf values (e.g. a bare Tensor) are handed to the function directly instead of raising,
    # so non-collection batches are supported as well.
    return function(data, *args, **kwargs)
def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
    if isinstance(device, str):
        device = torch.device(device)

    def batch_to(data: Any) -> Any:
        # Anything without a ``.to`` method (e.g. plain Python scalars) is returned unchanged.
        if not hasattr(data, "to"):
            return data
        kwargs = {}
        if isinstance(data, Tensor) and isinstance(device, torch.device):
            # Non-blocking transfers let host-to-device copies overlap with computation when memory is pinned.
            kwargs["non_blocking"] = True
        data_output = data.to(device, **kwargs)
        if data_output is not None:
            return data_output
        return data

    return apply_to_collection(batch, function=batch_to)
class _DistributonDataLoader:
    def __init__(self, dataloader: DataLoader, device: Optional[Union[str, torch.device]] = None) -> None:
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
self._device = device
self._num_iter_calls = 0
def __len__(self) -> int:
return len(self._dataloader)
def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        # Without setting the epoch, the distributed sampler would return the same indices every time, even
        # when shuffling is enabled. Doing it here means the user does not have to call `.set_epoch()` manually.
_set_sampler_epoch(self._dataloader, self._num_iter_calls)
self._num_iter_calls += 1
if self._device is None:
yield from iter(self._dataloader)
else:
for item in self._dataloader:
yield move_data_to_device(item, self._device)
def collate_fn(
    batch: List[Tensor],
    max_length: int = 1024,
) -> Tuple[Tensor, Tensor]:
    """Pads/truncates a batch of token sequences and builds next-token (input, target) pairs."""
    x = torch.zeros(len(batch), max_length, dtype=torch.long)
    y = torch.zeros(len(batch), max_length, dtype=torch.long)
    for i, encoding in enumerate(batch):
        # The target is the input shifted left by one token, so the usable length is len(encoding) - 1.
        seq_length = min(len(encoding) - 1, max_length)
        x[i, :seq_length] = encoding[:seq_length]
        y[i, :seq_length] = encoding[1 : seq_length + 1]
    return x, y
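# Worked example (illustrative, not part of the original commit): with max_length=4 and a single
# encoding tensor([11, 12, 13, 14, 15]), collate_fn returns
#     x = [[11, 12, 13, 14]]   (inputs)
#     y = [[12, 13, 14, 15]]   (targets: the inputs shifted left by one token)
# Shorter sequences are zero-padded on the right.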
def build_dataloader(
    dataset: Dataset,
    bs: int,
    seq_len: int,
    device: str,
    distributed: bool,
    shuffle: bool = True,
):
    # TODO: In the future we will need to use our own IterableDataset
    # with the logic for the distributed computation.
    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=shuffle)
    dataloader = DataLoader(
        dataset,
        batch_size=bs,
        # `shuffle` and `sampler` are mutually exclusive: with a DistributedSampler,
        # shuffling is delegated to the sampler itself.
        shuffle=shuffle if sampler is None else False,
        sampler=sampler,
        collate_fn=partial(collate_fn, max_length=seq_len),
        drop_last=True,
    )
    if distributed:
        dataloader = _DistributonDataLoader(dataloader=dataloader, device=device)
    return dataloader
\ No newline at end of file
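A minimal usage sketch for the dataloader helpers above (the toy dataset is illustrative and not part of the commit):

    import torch
    from torch.utils.data import Dataset

    class TokenDataset(Dataset):
        """Toy dataset returning 1-D LongTensors of token ids."""
        def __init__(self, sequences):
            self.sequences = sequences
        def __len__(self):
            return len(self.sequences)
        def __getitem__(self, idx):
            return self.sequences[idx]

    dataset = TokenDataset([torch.randint(0, 100, (257,)) for _ in range(64)])
    # Non-distributed path: a plain DataLoader using the next-token collate_fn.
    loader = build_dataloader(dataset, bs=8, seq_len=256, device="cpu", distributed=False)
    x, y = next(iter(loader))  # both (8, 256) LongTensors; y is x shifted by one token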
from functools import partial
from typing import Any
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import Tensor
from .strategy import DataParallelism
from .environment import SingleNodeEnvironment
from .dataloader import _DistributonDataLoader
class Distributon:
"""
We have three main components: the launcher, the strategy and the environment.
"""
def __init__(self, devices):
self._num_nodes = 1
self.num_processes = len(devices)
        self._strategy = DataParallelism(devices, SingleNodeEnvironment())
def launch(self, *args: Any, **kwargs: Any):
"""Launch and initialize the processes needed for the distributed training. """
to_run = partial(self._strategy.setup_environment)
self._strategy.launcher.launch(to_run, args, kwargs)
def setup_optimizer(self, optimizer):
""" Returns wrapper over optimizer """
optimizer = self._strategy.setup_optimizer(optimizer)
# TODO: we need to add a wrapper over the optimizer
# to delegate optimizer step calls to the strategy
return optimizer
def setup_model(self, model: nn.Module) -> Any:
""" Returns wrapper over model """
# Move the model to the device
self._strategy.model_to_device(model=model)
model = self._strategy.setup_module(model)
return model
def setup_dataloaders(self, dataloader: DataLoader):
""" Returns wrapper over dataloadewr """
distributon_dataloader = _DistributonDataLoader(dataloader=dataloader, device=self._strategy.root_device)
return distributon_dataloader
def backward(self, tensor: Tensor, *args: Any, **kwargs: Any) -> None:
r"""Forwards backward-calls to the precision plugin."""
# TODO: Use mixed precission for this
tensor.backward(*args, **kwargs)
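A sketch of how these pieces are meant to fit together in a training script (the model, dataset and device list below are illustrative placeholders, not part of the commit):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    distributon = Distributon(devices)
    distributon.launch()  # spawns one child process per extra device and joins the NCCL process group

    model = distributon.setup_model(nn.Linear(1024, 1024))  # moved to the local GPU and wrapped in DDP
    optimizer = distributon.setup_optimizer(torch.optim.SGD(model.parameters(), lr=1e-3))
    data = TensorDataset(torch.randn(64, 1024), torch.randn(64, 1024))
    # In practice build_dataloader(..., distributed=True) would shard the data across ranks;
    # a plain DataLoader is used here only to keep the sketch short.
    dataloader = distributon.setup_dataloaders(DataLoader(data, batch_size=8))

    for x, y in dataloader:
        optimizer.zero_grad()
        loss = (model(x) - y).pow(2).mean()
        distributon.backward(loss)
        optimizer.step()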
import os
from abc import abstractmethod
from typing_extensions import override
class Environment:
@abstractmethod
def local_rank(self) -> int:
"""The rank of the currently running process inside of the current node."""
@abstractmethod
def main_address(self) -> str:
"""The IP address of the master node."""
@abstractmethod
def main_port(self) -> str:
"""The port used for communication with the master node."""
@abstractmethod
def set_world_size(self, size: int) -> None:
"""Sets the world size."""
@property
@abstractmethod
def creates_processes_externally(self) -> bool:
"""Whether the environment creates the subprocesses or not."""
class SingleNodeEnvironment(Environment):
@override
def local_rank(self) -> int:
return int(os.environ.get("LOCAL_RANK", 0))
@override
def main_address(self) -> str:
return os.environ.get("MASTER_ADDR", "127.0.0.1")
@override
def main_port(self) -> str:
return os.environ.get("MASTER_PORT", "12355")
@override
def set_world_size(self, size: int) -> None:
self.world_size = size
@property
@override
def creates_processes_externally(self) -> bool:
"""Returns whether the cluster creates the processes or not.
If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
process launcher/job scheduler and Lightning will not launch new processes.
"""
return "LOCAL_RANK" in os.environ
import logging
import os
import subprocess
import sys
from typing import (
Any,
Callable,
Sequence,
List
)
from typing_extensions import override
import torch
from .environment import Environment
def _basic_subprocess_cmd() -> Sequence[str]:
import __main__
if __main__.__spec__ is None: # pragma: no-cover
return [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
def _num_cpus_available() -> int:
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))
cpu_count = os.cpu_count()
return 1 if cpu_count is None else cpu_count
def _suggested_max_num_threads(num_processes: int = 1) -> int:
if num_processes < 1:
raise ValueError(f"`num_processes` should be >= 1, got {num_processes}.")
return max(1, _num_cpus_available() // num_processes)
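# Example (illustrative): on a machine exposing 32 usable cores to this process and running 4 worker
# processes, _suggested_max_num_threads(4) == 32 // 4 == 8, so each worker gets 8 intra-op threads.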
class Launcher:
    """Base class for launchers that create the worker processes."""
class SubprocessLauncher(Launcher):
def __init__(
self,
environment: Environment,
num_processes: int,
num_nodes: int,
) -> None:
self.num_processes = num_processes
self.num_nodes = num_nodes
self.environment = environment
self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher
    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        # Launch the worker processes (only when they are not created externally).
        print("[INFO] Creating child processes")
        if not self.environment.creates_processes_externally:
            self._call_children_scripts()
        # TODO: launch a process observer over self.procs
        # Cap the number of intra-op threads so the workers do not oversubscribe the CPU.
        if "OMP_NUM_THREADS" not in os.environ:
            num_threads = _suggested_max_num_threads(self.num_processes)
            torch.set_num_threads(num_threads)
            os.environ["OMP_NUM_THREADS"] = str(num_threads)
        return function(*args, **kwargs)
    def _call_children_scripts(self) -> None:
        # Sanity check: we are the main process and the children have not been started yet.
        assert len(self.procs) == 0
        assert self.environment.local_rank() == 0
        # DDP environment variables shared with the children.
        os.environ["MASTER_ADDR"] = self.environment.main_address()
        os.environ["MASTER_PORT"] = str(self.environment.main_port())
        print("[INFO] Launching {} processes".format(self.num_processes))
        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"
            # remove env var if the global seed was not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]
            # Re-run the current script for this rank, redirecting its output to per-rank log files.
            command = _basic_subprocess_cmd()
            stdout_file = open(f"node_{local_rank}.out", "w")
            stderr_file = open(f"node_{local_rank}.err", "w")
            proc = subprocess.Popen(command, env=env_copy, stdout=stdout_file, stderr=stderr_file)
            self.procs.append(proc)
\ No newline at end of file
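To make the relaunch protocol concrete (the script name and flags below are hypothetical): if training was started as python train.py --epochs 3, _basic_subprocess_cmd() reconstructs roughly

    [sys.executable, "/abs/path/to/train.py", "--epochs", "3"]

and _call_children_scripts starts that command once for every local_rank in 1..num_processes - 1, each with LOCAL_RANK set in its copy of the environment. The children re-run the same training script, rebuild the same Distributon, and skip spawning because creates_processes_externally is true for them, so each device ends up with exactly one worker process.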
import os
from contextlib import nullcontext
from abc import abstractmethod
from typing import (
Optional,
List,
Any
)
from typing_extensions import override
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel
from .environment import Environment
from .launcher import SubprocessLauncher
class Strategy:
"""Strategies for distributed training."""
def __init__(self, environment):
self.environment = environment
@abstractmethod
def setup_environment(self):
"""Setup the environment for the parallelism strategy."""
@abstractmethod
def setup_module(self, module):
"""Setups the model."""
@abstractmethod
def setup_optimizer(self, optimizer):
"""Setups the model."""
@abstractmethod
def model_to_device(self, model):
"""Moves the model to device."""
@abstractmethod
def backwards(self, tensor: Tensor, *args: Any, **kwargs: Any):
"""Performs a backwards pass."""
@property
def local_rank(self) -> int:
"""Returns the local rank"""
return self.environment.local_rank()
class DataParallelism(Strategy):
    """Data-parallelism strategy (DDP) on a single node."""
    def __init__(self, devices: Optional[List[torch.device]] = None, environment: Optional[Environment] = None):
        super().__init__(environment)
        self.devices = devices
        self.num_processes = len(devices)
        # Launcher that spawns one subprocess per additional local device.
        self.launcher = SubprocessLauncher(self.environment, self.num_processes, 1)
@property
def root_device(self) -> torch.device:
assert self.devices is not None
return self.devices[self.local_rank]
@override
def setup_environment(self):
# Setup the CUDA device
torch.cuda.set_device(self.root_device)
# Setup distributed (currently only 1 node)
self.environment.set_world_size(1 * self.num_processes)
os.environ["MASTER_ADDR"] = self.environment.main_address()
os.environ["MASTER_PORT"] = str(self.environment.main_port())
print("[INFO] Initializing process group with {} and {}".format(self.local_rank, self.environment.world_size))
torch.distributed.init_process_group("nccl", rank=self.local_rank, world_size=self.environment.world_size)
@override
def setup_module(self, module: nn.Module):
assert module is not None
device_ids = [self.local_rank]
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids)
@override
def setup_optimizer(self, optimizer):
return optimizer
@override
def model_to_device(self, model: nn.Module) -> None:
model.to(self.root_device)
@override
def backwards(self, tensor: Tensor, *args: Any, **kwargs: Any):
tensor.backward(*args, **kwargs)