diff --git a/optimus/dataloader.py b/optimus/dataloader.py index 65869be320a434a6a2f87a6f3a9737d87a9c8d6f..5539156b555e6eae1834836e84b2eeaaab5c0d6a 100644 --- a/optimus/dataloader.py +++ b/optimus/dataloader.py @@ -1,5 +1,6 @@ import time import random +from typing import Tuple, Iterator, Iterable import torch from torch import Tensor @@ -8,7 +9,7 @@ from torch.utils.data import Dataset from tokenizer import Tokenizer -class _OptimusDLIter(): +class _OptimusDLIter(Iterator): def __init__(self, dl): """ _OptimusDL iterator. @@ -17,7 +18,7 @@ class _OptimusDLIter(): self.dl = dl self.curr = 0 - def __next__(self) -> (Tensor, Tensor): + def __next__(self) -> Tuple[Tensor, Tensor]: if self.curr > len(self.dl) - 1: raise StopIteration @@ -28,8 +29,7 @@ class _OptimusDLIter(): return x, y - -class _OptimusDL(): +class _OptimusDL(Iterable): def __init__(self, ds, tok, bs, seq_len, shuffle, device): """ See 'OptimusDataLoader'. @@ -80,7 +80,7 @@ class _OptimusDL(): """ self.device = device - def __iter__(self) -> _OptimusDLIter: + def __iter__(self) -> Iterator[_OptimusDLIter]: """ Return an iterator over the dataloader object. diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py index 81d2829bf6632b00ba0938e30fcad202c3d6ef49..f84b0fe67aed189c85d8b2519661b70e325ebeb6 100644 --- a/optimus/datasets/tinystories.py +++ b/optimus/datasets/tinystories.py @@ -27,13 +27,13 @@ _EXTRACTED_FILES = { class TinyStoriesDataset(Dataset): - def __init__(self, root: str = None, split: str = 'train'): + def __init__(self, root: str | None = None, split: str = 'train'): """ TinyStories dataset. Args: - root (str): Directory where the dataset is saved. Defaults to - os.path.expanduser('~/.cache/optimus/'). + root (str | None): Directory where the dataset is saved. Defaults to + os.path.expanduser('~/.cache/optimus/') if None. split (str): Split to be returned. Can be 'train', 'test' or 'valid'. diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py index 50b170f68befb2ec1511a8b80b824a1361e220d1..daaecc91c96c5d2896ab45f0b6fd2f6044c0cefe 100644 --- a/optimus/datasets/wikitext103.py +++ b/optimus/datasets/wikitext103.py @@ -20,13 +20,13 @@ _EXTRACTED_FILES = { class WikiText103Dataset(Dataset): - def __init__(self, root: str = None, split: str = 'train'): + def __init__(self, root: str | None = None, split: str = 'train'): """ WikiText103 dataset. Args: - root (str): Directory where the dataset is saved. Defaults to - os.path.expanduser('~/.cache/optimus/'). + root (str | None): Directory where the dataset is saved. Defaults to + os.path.expanduser('~/.cache/optimus/') if None. split (str): Split to be returned. Can be 'train', 'test' or 'valid'. diff --git a/optimus/example_training.py b/optimus/example_training.py index d92d38cca7db059391d7de70b61bcbc869f95af3..b3619a242703fd734541fa6f227f7dc398098c84 100644 --- a/optimus/example_training.py +++ b/optimus/example_training.py @@ -1,12 +1,6 @@ -import sys -import time -import math -import argparse - import fire import torch from torch import nn -from torch.utils.data import Dataset from datasets.wikitext103 import WikiText103Dataset from tokenizer import Tokenizer diff --git a/optimus/model.py b/optimus/model.py index 108feafc6b9ec1547ed97a028d3e88051cc47ad5..5a432655c81579bcc42b0b001db2750bd8d0d735 100644 --- a/optimus/model.py +++ b/optimus/model.py @@ -53,7 +53,7 @@ class Attention(nn.Module): def forward(self, x: torch.Tensor, - mask: torch.Tensor = None) -> torch.Tensor: + mask: torch.Tensor | None = None) -> torch.Tensor: bs, seq_len, _ = x.shape # (batch_size, seq_len, dim) diff --git a/optimus/tokenizer.py b/optimus/tokenizer.py index 170d4a5caf946894809fbf4a27a9e2f8b5dfe8ba..0a9ac95a337941ad3f732fcaa89d881f50eb8ca0 100644 --- a/optimus/tokenizer.py +++ b/optimus/tokenizer.py @@ -77,12 +77,12 @@ class Tokenizer(): """ assert type(input) is str - input = self.sp_model.encode_as_ids(input) + ids: List[int] = self.sp_model.encode_as_ids(input) if bos: - input = [self.bos_id] + input + ids = [self.bos_id] + ids if eos: - input = input + [self.eos_id] - return input + ids = ids + [self.eos_id] + return ids def encode_as_pieces(self, input: str) -> List[str]: """ diff --git a/optimus/trainer.py b/optimus/trainer.py index 43c63714a1c3ae9336683f08627c227a046d6581..065d61a49f2a9c3c7ef57f1d51c935f0e2e0e97b 100644 --- a/optimus/trainer.py +++ b/optimus/trainer.py @@ -1,12 +1,13 @@ -import os import time import math +from typing import Callable import torch import torch.nn as nn +import torch.optim as optim +from fastprogress.fastprogress import master_bar, progress_bar, format_time from dataloader import OptimusDataLoader -from fastprogress.fastprogress import master_bar, progress_bar, format_time class Trainer(): @@ -14,8 +15,8 @@ class Trainer(): def __init__(self, dl: OptimusDataLoader, model: nn.Module, - criterion: callable, - optimizer: torch.optim.Optimizer, + criterion: Callable, + optimizer: optim.Optimizer, lr: float, grad_acc_steps: int, grad_clip_norm: float,