diff --git a/training.py b/training.py
index 1cd22ae96230bbc390d1e2506b1a4eef4a81f7c9..64cdf8f57d96c7bdc19e532c74aa67a6fa6035fe 100644
--- a/training.py
+++ b/training.py
@@ -89,9 +89,14 @@ def main(batch_size: int = 8,
     _total_params = sum(p.numel() for p in model.parameters())
     print(f"Number of model parameters: {_total_params}")
 
-    # define loss metric and optimizer
+    # define loss metric
     criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-9)
+
+    # define optimizer
+    # see [1] for a discussion on what the epsilon value should be for amp; 1e-7
+    # is a good default for both amp and normal training
+    # [1]: https://github.com/pytorch/pytorch/issues/26218
+    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-7)
 
     print("Starting training...")
 
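For reviewers, a note on why the epsilon bump matters: float16's smallest subnormal is roughly 6e-8, so an eps of 1e-9 flushes to zero wherever Adam's denominator (sqrt(v_hat) + eps) is evaluated in half precision, removing its safety margin, while 1e-7 remains representable. Below is a minimal sketch of the underflow and of the new optimizer settings inside a torch.cuda.amp training step; the model and the train_step helper are hypothetical stand-ins for illustration, not part of this patch:

```python
import torch
import torch.nn as nn

# float16 underflow: 1e-9 is below float16's smallest subnormal (~6e-8),
# while 1e-7 rounds to the nearest representable subnormal (~1.19e-7).
print(torch.tensor(1e-9, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-7, dtype=torch.float16))  # tensor(1.1921e-07, dtype=torch.float16)

# Hypothetical model; the real training.py builds its own model elsewhere.
model = nn.Linear(512, 10).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-7)
scaler = torch.cuda.amp.GradScaler()

def train_step(inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # forward pass runs in mixed precision
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales grads, then calls optimizer.step()
    scaler.update()
```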