diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 65869be320a434a6a2f87a6f3a9737d87a9c8d6f..5539156b555e6eae1834836e84b2eeaaab5c0d6a 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -1,5 +1,6 @@
 import time
 import random
+from typing import Tuple, Iterator, Iterable
 
 import torch
 from torch import Tensor
@@ -8,7 +9,7 @@ from torch.utils.data import Dataset
 from tokenizer import Tokenizer
 
 
-class _OptimusDLIter():
+class _OptimusDLIter(Iterator):
     def __init__(self, dl):
         """
         _OptimusDL iterator.
@@ -17,7 +18,7 @@ class _OptimusDLIter():
         self.dl = dl
         self.curr = 0
 
-    def __next__(self) -> (Tensor, Tensor):
+    def __next__(self) -> Tuple[Tensor, Tensor]:
         if self.curr > len(self.dl) - 1:
             raise StopIteration
 
@@ -28,8 +29,8 @@
 
         return x, y
 
 
-class _OptimusDL():
+class _OptimusDL(Iterable):
     def __init__(self, ds, tok, bs, seq_len, shuffle, device):
         """
         See 'OptimusDataLoader'.
@@ -80,7 +80,7 @@ class _OptimusDL():
         """
         self.device = device
 
-    def __iter__(self) -> _OptimusDLIter:
+    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
         """
         Return an iterator over the dataloader object.
 
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index 81d2829bf6632b00ba0938e30fcad202c3d6ef49..f84b0fe67aed189c85d8b2519661b70e325ebeb6 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -27,13 +27,13 @@ _EXTRACTED_FILES = {
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         TinyStories dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py
index 50b170f68befb2ec1511a8b80b824a1361e220d1..daaecc91c96c5d2896ab45f0b6fd2f6044c0cefe 100644
--- a/optimus/datasets/wikitext103.py
+++ b/optimus/datasets/wikitext103.py
@@ -20,13 +20,13 @@ _EXTRACTED_FILES = {
 
 class WikiText103Dataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         WikiText103 dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/example_training.py b/optimus/example_training.py
index d92d38cca7db059391d7de70b61bcbc869f95af3..b3619a242703fd734541fa6f227f7dc398098c84 100644
--- a/optimus/example_training.py
+++ b/optimus/example_training.py
@@ -1,12 +1,6 @@
-import sys
-import time
-import math
-import argparse
-
 import fire
 import torch
 from torch import nn
-from torch.utils.data import Dataset
 
 from datasets.wikitext103 import WikiText103Dataset
 from tokenizer import Tokenizer
diff --git a/optimus/model.py b/optimus/model.py
index 108feafc6b9ec1547ed97a028d3e88051cc47ad5..5a432655c81579bcc42b0b001db2750bd8d0d735 100644
--- a/optimus/model.py
+++ b/optimus/model.py
@@ -53,7 +53,7 @@ class Attention(nn.Module):
 
     def forward(self,
                 x: torch.Tensor,
-                mask: torch.Tensor = None) -> torch.Tensor:
+                mask: torch.Tensor | None = None) -> torch.Tensor:
 
         bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
 
diff --git a/optimus/tokenizer.py b/optimus/tokenizer.py
index 170d4a5caf946894809fbf4a27a9e2f8b5dfe8ba..0a9ac95a337941ad3f732fcaa89d881f50eb8ca0 100644
--- a/optimus/tokenizer.py
+++ b/optimus/tokenizer.py
@@ -77,12 +77,12 @@ class Tokenizer():
 
         """
         assert type(input) is str
-        input = self.sp_model.encode_as_ids(input)
+        ids: List[int] = self.sp_model.encode_as_ids(input)
         if bos:
-            input = [self.bos_id] + input
+            ids = [self.bos_id] + ids
         if eos:
-            input = input + [self.eos_id]
-        return input
+            ids = ids + [self.eos_id]
+        return ids
 
     def encode_as_pieces(self, input: str) -> List[str]:
         """
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 43c63714a1c3ae9336683f08627c227a046d6581..065d61a49f2a9c3c7ef57f1d51c935f0e2e0e97b 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,12 +1,13 @@
-import os
 import time
 import math
+from typing import Callable
 
 import torch
 import torch.nn as nn
+import torch.optim as optim
+from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 from dataloader import OptimusDataLoader
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
 class Trainer():
@@ -14,8 +15,8 @@ class Trainer():
     def __init__(self,
                  dl: OptimusDataLoader,
                  model: nn.Module,
-                 criterion: callable,
-                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 optimizer: optim.Optimizer,
                  lr: float,
                  grad_acc_steps: int,
                  grad_clip_norm: float,