From a092db0a63eeaf3d748bef7952845a7b6c622269 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 22 Jan 2024 15:32:48 +0200
Subject: [PATCH] Fix a number of type issues

Fix problems with some type annotations. Python does not enforce
annotations at runtime, so fixing them enables static type checkers
(e.g. mypy) to correctly identify these issues before the code runs.
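
As an illustration (hypothetical functions, not part of this patch),
the tuple-literal return annotation fixed in dataloader.py is valid at
runtime but rejected by mypy; the typing.Tuple spelling is the
accepted form:

    from typing import Tuple
    from torch import Tensor

    def bad() -> (Tensor, Tensor): ...        # runs, but mypy rejects
                                              # the tuple literal
    def good() -> Tuple[Tensor, Tensor]: ...  # accepted by mypy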
---
 optimus/dataloader.py           |  9 +++++----
 optimus/datasets/tinystories.py |  6 +++---
 optimus/datasets/wikitext103.py |  6 +++---
 optimus/example_training.py     |  6 ------
 optimus/model.py                |  2 +-
 optimus/tokenizer.py            |  8 ++++----
 optimus/trainer.py              |  9 +++++----
 7 files changed, 21 insertions(+), 25 deletions(-)
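
Note for reviewers (illustrative only, hypothetical function, not
applied by this patch): the implicit-None defaults fixed below all
follow the same pattern; a strict type checker flags a None default
against a plain str annotation:

    def load(root: str = None): ...          # flagged: None is not a str
    def load(root: str | None = None): ...   # OK (PEP 604, Python 3.10+)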

diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 65869be..5539156 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -1,5 +1,6 @@
 import time
 import random
+from typing import Tuple, Iterator, Iterable
 
 import torch
 from torch import Tensor
@@ -8,7 +9,7 @@ from torch.utils.data import Dataset
 from tokenizer import Tokenizer
 
 
-class _OptimusDLIter():
+class _OptimusDLIter(Iterator[Tuple[Tensor, Tensor]]):
     def __init__(self, dl):
         """
         _OptimusDL iterator.
@@ -17,7 +18,7 @@ class _OptimusDLIter():
         self.dl = dl
         self.curr = 0
 
-    def __next__(self) -> (Tensor, Tensor):
+    def __next__(self) -> Tuple[Tensor, Tensor]:
         if self.curr > len(self.dl) - 1:
             raise StopIteration
 
@@ -28,8 +29,8 @@ class _OptimusDLIter():
 
         return x, y
 
 
-class _OptimusDL():
+class _OptimusDL(Iterable):
     def __init__(self, ds, tok, bs, seq_len, shuffle, device):
         """
         See 'OptimusDataLoader'.
@@ -80,7 +81,7 @@ class _OptimusDL():
         """
         self.device = device
 
-    def __iter__(self) -> _OptimusDLIter:
+    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
         """
         Return an iterator over the dataloader object.
 
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index 81d2829..f84b0fe 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -27,13 +27,13 @@ _EXTRACTED_FILES = {
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         TinyStories dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. If
+                None, defaults to os.path.expanduser('~/.cache/optimus/').
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py
index 50b170f..daaecc9 100644
--- a/optimus/datasets/wikitext103.py
+++ b/optimus/datasets/wikitext103.py
@@ -20,13 +20,13 @@ _EXTRACTED_FILES = {
 
 class WikiText103Dataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         WikiText103 dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. If
+                None, defaults to os.path.expanduser('~/.cache/optimus/').
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/example_training.py b/optimus/example_training.py
index d92d38c..b3619a2 100644
--- a/optimus/example_training.py
+++ b/optimus/example_training.py
@@ -1,12 +1,6 @@
-import sys
-import time
-import math
-import argparse
-
 import fire
 import torch
 from torch import nn
-from torch.utils.data import Dataset
 
 from datasets.wikitext103 import WikiText103Dataset
 from tokenizer import Tokenizer
diff --git a/optimus/model.py b/optimus/model.py
index 108feaf..5a43265 100644
--- a/optimus/model.py
+++ b/optimus/model.py
@@ -53,7 +53,7 @@ class Attention(nn.Module):
 
     def forward(self,
                 x: torch.Tensor,
-                mask: torch.Tensor = None) -> torch.Tensor:
+                mask: torch.Tensor | None = None) -> torch.Tensor:
 
         bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
 
diff --git a/optimus/tokenizer.py b/optimus/tokenizer.py
index 170d4a5..0a9ac95 100644
--- a/optimus/tokenizer.py
+++ b/optimus/tokenizer.py
@@ -77,12 +77,12 @@ class Tokenizer():
 
         """
         assert type(input) is str
-        input = self.sp_model.encode_as_ids(input)
+        ids: List[int] = self.sp_model.encode_as_ids(input)
         if bos:
-            input = [self.bos_id] + input
+            ids = [self.bos_id] + ids
         if eos:
-            input = input + [self.eos_id]
-        return input
+            ids = ids + [self.eos_id]
+        return ids
 
     def encode_as_pieces(self, input: str) -> List[str]:
         """
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 43c6371..065d61a 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,12 +1,13 @@
-import os
 import time
 import math
+from typing import Callable
 
 import torch
 import torch.nn as nn
+import torch.optim as optim
+from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 from dataloader import OptimusDataLoader
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
 class Trainer():
@@ -14,8 +15,8 @@ class Trainer():
     def __init__(self,
                  dl: OptimusDataLoader,
                  model: nn.Module,
-                 criterion: callable,
-                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 optimizer: optim.Optimizer,
                  lr: float,
                  grad_acc_steps: int,
                  grad_clip_norm: float,
-- 
GitLab