From a092db0a63eeaf3d748bef7952845a7b6c622269 Mon Sep 17 00:00:00 2001
From: Alexandru Gherghescu <gherghescu_alex1@yahoo.ro>
Date: Mon, 22 Jan 2024 15:32:48 +0200
Subject: [PATCH] Fix a number of type issues

Fix problems with some type annotations. This enables Python's static
type checkers to correctly identify some issues before runtime.
---
 optimus/dataloader.py           | 10 +++++-----
 optimus/datasets/tinystories.py |  6 +++---
 optimus/datasets/wikitext103.py |  6 +++---
 optimus/example_training.py     |  6 ------
 optimus/model.py                |  2 +-
 optimus/tokenizer.py            |  8 ++++----
 optimus/trainer.py              |  9 +++++----
 7 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 65869be..5539156 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -1,5 +1,6 @@
 import time
 import random
+from typing import Tuple, Iterator, Iterable
 
 import torch
 from torch import Tensor
@@ -8,7 +9,7 @@ from torch.utils.data import Dataset
 from tokenizer import Tokenizer
 
 
-class _OptimusDLIter():
+class _OptimusDLIter(Iterator):
     def __init__(self, dl):
         """
         _OptimusDL iterator.
@@ -17,7 +18,7 @@ class _OptimusDLIter():
         self.dl = dl
         self.curr = 0
 
-    def __next__(self) -> (Tensor, Tensor):
+    def __next__(self) -> Tuple[Tensor, Tensor]:
         if self.curr > len(self.dl) - 1:
             raise StopIteration
 
@@ -28,8 +29,7 @@ class _OptimusDLIter():
 
         return x, y
 
-
-class _OptimusDL():
+class _OptimusDL(Iterable):
     def __init__(self, ds, tok, bs, seq_len, shuffle, device):
         """
         See 'OptimusDataLoader'.
@@ -80,7 +80,7 @@ class _OptimusDL():
         """
         self.device = device
 
-    def __iter__(self) -> _OptimusDLIter:
+    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
         """
         Return an iterator over the dataloader object.
 
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index 81d2829..f84b0fe 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -27,13 +27,13 @@ _EXTRACTED_FILES = {
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         TinyStories dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py
index 50b170f..daaecc9 100644
--- a/optimus/datasets/wikitext103.py
+++ b/optimus/datasets/wikitext103.py
@@ -20,13 +20,13 @@ _EXTRACTED_FILES = {
 
 class WikiText103Dataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         WikiText103 dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/example_training.py b/optimus/example_training.py
index d92d38c..b3619a2 100644
--- a/optimus/example_training.py
+++ b/optimus/example_training.py
@@ -1,12 +1,6 @@
-import sys
-import time
-import math
-import argparse
-
 import fire
 import torch
 from torch import nn
-from torch.utils.data import Dataset
 
 from datasets.wikitext103 import WikiText103Dataset
 from tokenizer import Tokenizer
diff --git a/optimus/model.py b/optimus/model.py
index 108feaf..5a43265 100644
--- a/optimus/model.py
+++ b/optimus/model.py
@@ -53,7 +53,7 @@ class Attention(nn.Module):
 
     def forward(self,
                 x: torch.Tensor,
-                mask: torch.Tensor = None) -> torch.Tensor:
+                mask: torch.Tensor | None = None) -> torch.Tensor:
         bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
 
 
diff --git a/optimus/tokenizer.py b/optimus/tokenizer.py
index 170d4a5..0a9ac95 100644
--- a/optimus/tokenizer.py
+++ b/optimus/tokenizer.py
@@ -77,12 +77,12 @@ class Tokenizer():
         """
         assert type(input) is str
 
-        input = self.sp_model.encode_as_ids(input)
+        ids: List[int] = self.sp_model.encode_as_ids(input)
         if bos:
-            input = [self.bos_id] + input
+            ids = [self.bos_id] + ids
         if eos:
-            input = input + [self.eos_id]
-        return input
+            ids = ids + [self.eos_id]
+        return ids
 
     def encode_as_pieces(self, input: str) -> List[str]:
         """
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 43c6371..065d61a 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,12 +1,13 @@
-import os
 import time
 import math
+from typing import Callable
 
 import torch
 import torch.nn as nn
+import torch.optim as optim
+from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 from dataloader import OptimusDataLoader
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
 class Trainer():
@@ -14,8 +15,8 @@ class Trainer():
     def __init__(self,
                  dl: OptimusDataLoader,
                  model: nn.Module,
-                 criterion: callable,
-                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 optimizer: optim.Optimizer,
                  lr: float,
                  grad_acc_steps: int,
                  grad_clip_norm: float,
-- 
GitLab