diff --git a/optimus/trainer.py b/optimus/trainer.py
index 585471482207267b7278a438f3b7efa0867ffcef..0c0c38119ff11dbc74a14e58d7e79ec1a859b9d6 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -22,6 +22,7 @@ class Trainer():
                  grad_acc_steps: int,
                  grad_clip_norm: float,
                  model_save_path: str,
+                 use_fp16: bool,
                  progress_bar: bool = True):
         """
         Trainer implementation for Optimus models.
@@ -40,6 +41,9 @@ class Trainer():
             grad_clip_norm (float): Gradient clipping norm value.
             model_save_path (str): The best model (based on validation loss) is
                 saved to the specified path.
+            use_fp16 (bool): Whether to train the model in 16-bit floating point
+                precision. If the hardware does not support it, a warning is
+                issued and normal 32-bit precision is used instead.
             progress_bar (bool): Whether to show a progress bar in console while
                 training. This is automatically disabled if output is a file,
                 however some stats are printed after finishing epochs. If False,
@@ -58,6 +62,10 @@ class Trainer():
         self.grad_clip_norm = grad_clip_norm
 
         self.model_save_path = model_save_path
+
+        self.use_fp16 = use_fp16
+        self.fp16_dtype = torch.float16
+
         self.progress_bar = progress_bar
 
     def fit(self, n_epochs: int) -> None:
@@ -76,6 +84,9 @@ class Trainer():
             epochs=n_epochs,
             steps_per_epoch=len(self.dl.train) // self.grad_acc_steps)
 
+        # scaler used for mixed precision fp16 training on GPU
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16)
+
         best_val_loss = float('inf')
 
         # progress bar for epochs
@@ -120,11 +131,16 @@ class Trainer():
             if self.progress_bar is True:
                 pb.update(i)
 
-            output = self.model(x)
-            loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
-                                  y.reshape(-1))
-            loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
-            loss.backward()
+            # automatic mixed precision training
+            with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
+                                         enabled=self.use_fp16):
+                output = self.model(x)
+                loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
+                                      y.reshape(-1))
+
+            loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
+
+            self.scaler.scale(loss).backward()
 
             total_loss += loss.item()
 
@@ -135,10 +151,12 @@ class Trainer():
                 # number of batches doesn't cleanly divide by grad_acc_steps
 
                 # gradient clipping
+                self.scaler.unscale_(self.optimizer)
                 nn.utils.clip_grad_norm_(self.model.parameters(),
                                          max_norm=self.grad_clip_norm)
 
-                self.optimizer.step()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
                 self.optimizer.zero_grad()
                 self.scheduler.step()
 
@@ -173,9 +191,12 @@ class Trainer():
             if self.progress_bar is True:
                 pb.update(i)
 
-            output = self.model(x)
-            loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
-                                  y.reshape(-1))
+            with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
+                                         enabled=self.use_fp16):
+                output = self.model(x)
+                loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
+                                      y.reshape(-1))
+
             total_loss += loss.item()
             self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
 
diff --git a/training.py b/training.py
index 3efbfd21d7813266790faab6f789fa8d4d8a3d5a..1cd22ae96230bbc390d1e2506b1a4eef4a81f7c9 100644
--- a/training.py
+++ b/training.py
@@ -21,7 +21,7 @@ def main(batch_size: int = 8,
          n_layers: int = 6,
          n_heads: int = 8,
          dropout: float = 0.0,
-         device: str = 'cuda'):
+         use_fp16: bool = True):
     """
     Run the main training loop for the model.
 
@@ -40,8 +40,8 @@ def main(batch_size: int = 8,
         n_layers (int): Number of layers for the model.
         n_heads (int): Number of heads inside an attention layer for the model.
         dropout (float): Dropout to use for the model.
-        device (str): Device where to train the model. Viable options are 'cpu',
-            'cuda', 'cuda:2' etc.
+        use_fp16 (bool): Whether to train using 16-bit floating point
+            precision.
     """
 
@@ -58,7 +58,7 @@ def main(batch_size: int = 8,
           f"\t- model layers: {n_layers}\n"
           f"\t- model attention heads: {n_heads}\n"
           f"\t- model dropout: {dropout}\n"
-          f"\t- training on device: {device}\n"
+          f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
           f"Please see '--help' if you want to change these settings")
 
     # load tokenizer
@@ -75,7 +75,7 @@ def main(batch_size: int = 8,
     dl = OptimusDataLoader(train_ds, test_ds, tok,
                            bs=batch_size,
                            seq_len=seq_len,
-                           device=device)
+                           device='cuda')
 
     # create model and move to device
     model = OptimusTransformer(len(tok),
@@ -84,7 +84,7 @@ def main(batch_size: int = 8,
                                n_heads=n_heads,
                                p_drop=dropout,
                                weight_tying=False)
-    model = model.to(device)
+    model = model.to('cuda')
 
     _total_params = sum(p.numel() for p in model.parameters())
     print(f"Number of model parameters: {_total_params}")
@@ -104,6 +104,7 @@ def main(batch_size: int = 8,
                       grad_acc_steps=grad_acc_steps,
                       grad_clip_norm=grad_clip_norm,
                       model_save_path=checkpoints_path,
+                      use_fp16=use_fp16,
                       progress_bar=True)
 
     trainer.fit(epochs)
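
For reviewers who want to exercise the new code path in isolation, the patch follows the standard PyTorch AMP recipe: run the forward pass and loss under autocast, scale the loss before backward() so fp16 gradients don't underflow, unscale before clipping so the norm is measured on real gradients, and let the scaler drive the optimizer step and its own update. Below is a minimal, self-contained sketch of that recipe; the toy model, random data and hyperparameters (model, criterion, grad_acc_steps, grad_clip_norm, ...) are placeholders, not objects from this repository.

import torch
import torch.nn as nn

# Placeholders standing in for the real Optimus components.
use_fp16 = torch.cuda.is_available()   # fp16 only makes sense on GPU here
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = nn.Linear(16, 4).to(device)    # stand-in for OptimusTransformer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_acc_steps = 4
grad_clip_norm = 1.0

# GradScaler rescales the fp16 loss so small gradients don't underflow;
# with enabled=False it is a no-op and the loop runs in plain fp32.
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

for step in range(8):
    x = torch.randn(32, 16, device=device)
    y = torch.randint(0, 4, (32,), device=device)

    # forward pass and loss computed in fp16 under autocast
    with torch.cuda.amp.autocast(dtype=torch.float16, enabled=use_fp16):
        output = model(x)
        loss = criterion(output, y)

    loss = loss / grad_acc_steps       # gradient accumulation
    scaler.scale(loss).backward()      # backward on the scaled loss

    if (step + 1) % grad_acc_steps == 0:
        # unscale first so clipping sees the true gradient norm
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_norm)

        scaler.step(optimizer)         # skips the step on inf/NaN gradients
        scaler.update()
        optimizer.zero_grad()

With enabled=False both autocast and GradScaler degrade to no-ops, which is how the same training loop covers the 32-bit fallback mentioned in the docstring.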