diff --git a/optimus/trainer.py b/optimus/trainer.py
index 585471482207267b7278a438f3b7efa0867ffcef..0c0c38119ff11dbc74a14e58d7e79ec1a859b9d6 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -22,6 +22,7 @@ class Trainer():
                  grad_acc_steps: int,
                  grad_clip_norm: float,
                  model_save_path: str,
+                 use_fp16: bool,
                  progress_bar: bool = True):
         """
         Trainer implementation for Optimus models.
@@ -40,6 +41,9 @@ class Trainer():
             grad_clip_norm (float): Gradient clipping norm value.
             model_save_path (str): The best model (based on validation loss) is
                 saved to the specified path.
+            use_fp16 (bool): Whether to train the model in 16-bit floating-point
+                precision. If the hardware does not support it, a warning is
+                issued and regular 32-bit precision is used instead.
             progress_bar (bool): Whether to show a progress bar in console while
                 training. This is automatically disabled if output is a file,
                 however some stats are printed after finishing epochs. If False,
@@ -58,6 +62,10 @@ class Trainer():
 
         self.grad_clip_norm = grad_clip_norm
         self.model_save_path = model_save_path
+
+        self.use_fp16 = use_fp16
+        self.fp16_dtype = torch.float16
+
         self.progress_bar = progress_bar
 
     def fit(self, n_epochs: int) -> None:
@@ -76,6 +84,9 @@ class Trainer():
             epochs=n_epochs,
             steps_per_epoch=len(self.dl.train) // self.grad_acc_steps)
 
+        # gradient scaler used for mixed-precision (fp16) training on GPU
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16)
+
         best_val_loss = float('inf')
 
         # progress bar for epochs
@@ -120,11 +131,16 @@ class Trainer():
             if self.progress_bar is True:
                 pb.update(i)
 
-            output = self.model(x)
-            loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
-                                  y.reshape(-1))
-            loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
-            loss.backward()
+            # automatic mixed precision training
+            with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
+                                         enabled=self.use_fp16):
+                output = self.model(x)
+                loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
+                                      y.reshape(-1))
+
+                loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
+
+            self.scaler.scale(loss).backward()
 
             total_loss += loss.item()
 
@@ -135,10 +151,12 @@ class Trainer():
                 # number of batches doesn't cleanly divide by grad_acc_steps
 
                 # gradient clipping
+                self.scaler.unscale_(self.optimizer)
                 nn.utils.clip_grad_norm_(self.model.parameters(),
                                          max_norm=self.grad_clip_norm)
 
-                self.optimizer.step()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
                 self.optimizer.zero_grad()
 
                 self.scheduler.step()
@@ -173,9 +191,12 @@ class Trainer():
                 if self.progress_bar is True:
                     pb.update(i)
 
-                output = self.model(x)
-                loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
-                                      y.reshape(-1))
+                with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
+                                             enabled=self.use_fp16):
+                    output = self.model(x)
+                    loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
+                                          y.reshape(-1))
+
                 total_loss += loss.item()
 
                 self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
diff --git a/training.py b/training.py
index 3efbfd21d7813266790faab6f789fa8d4d8a3d5a..1cd22ae96230bbc390d1e2506b1a4eef4a81f7c9 100644
--- a/training.py
+++ b/training.py
@@ -21,7 +21,7 @@ def main(batch_size: int = 8,
          n_layers: int = 6,
          n_heads: int = 8,
          dropout: float = 0.0,
-         device: str = 'cuda'):
+         use_fp16: bool = True):
     """
     Run the main training loop for the model.
 
@@ -40,8 +40,8 @@ def main(batch_size: int = 8,
         n_layers (int): Number of layers for the model.
         n_heads (int): Number of heads inside an attention layer for the model.
         dropout (float): Dropout to use for the model.
-        device (str): Device where to train the model. Viable options are 'cpu',
-            'cuda', 'cuda:2' etc.
+        use_fp16 (bool): Whether to train the model using 16-bit
+            floating-point precision.
 
     """
 
@@ -58,7 +58,7 @@ def main(batch_size: int = 8,
         f"\t- model layers: {n_layers}\n"
         f"\t- model attention heads: {n_heads}\n"
         f"\t- model dropout: {dropout}\n"
-        f"\t- training on device: {device}\n"
+        f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
         f"Please see '--help' if you want to change these settings")
 
     # load tokenizer
@@ -75,7 +75,7 @@ def main(batch_size: int = 8,
     dl = OptimusDataLoader(train_ds, test_ds, tok,
                            bs=batch_size,
                            seq_len=seq_len,
-                           device=device)
+                           device='cuda')
 
     # create model and move to device
     model = OptimusTransformer(len(tok),
@@ -84,7 +84,7 @@ def main(batch_size: int = 8,
                                n_heads=n_heads,
                                p_drop=dropout,
                                weight_tying=False)
-    model = model.to(device)
+    model = model.to('cuda')
 
     _total_params = sum(p.numel() for p in model.parameters())
     print(f"Number of model parameters: {_total_params}")
@@ -104,6 +104,7 @@ def main(batch_size: int = 8,
                       grad_acc_steps=grad_acc_steps,
                       grad_clip_norm=grad_clip_norm,
                       model_save_path=checkpoints_path,
+                      use_fp16=use_fp16,
                       progress_bar=True)
     trainer.fit(epochs)
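
Note: the new use_fp16 docstring promises a warning and a fall-back to 32-bit precision when fp16 is not supported, but that check is not visible in this diff. One hypothetical way it could be implemented (the helper name `resolve_fp16` is illustrative, not part of the codebase):

    import warnings
    import torch

    def resolve_fp16(use_fp16: bool) -> bool:
        # fp16 autocast with GradScaler, as used in Trainer, requires a CUDA device
        if use_fp16 and not torch.cuda.is_available():
            warnings.warn("fp16 training requested but no CUDA device is "
                          "available; falling back to 32-bit precision.")
            return False
        return use_fp16

    # e.g. in Trainer.__init__: self.use_fp16 = resolve_fp16(use_fp16)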