diff --git a/training.py b/training.py
index 1cd22ae96230bbc390d1e2506b1a4eef4a81f7c9..64cdf8f57d96c7bdc19e532c74aa67a6fa6035fe 100644
--- a/training.py
+++ b/training.py
@@ -89,9 +89,14 @@ def main(batch_size: int = 8,
     _total_params = sum(p.numel() for p in model.parameters())
     print(f"Number of model parameters: {_total_params}")
 
-    # define loss metric and optimizer
+    # define loss criterion
     criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-9)
+
+    # define optimizer
+    # see [1] for a discussion of the appropriate epsilon value under AMP;
+    # 1e-7 is a good default for both mixed- and full-precision training
+    # [1]: https://github.com/pytorch/pytorch/issues/26218
+    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-7)
 
     print("Starting training...")