Unverified commit a092db0a authored by Alexandru-Mihai GHERGHESCU

Fix a number of type issues

Fix problems with some types. This enables static type checkers to
correctly identify some issues before runtime.
parent 5521d924
1 merge request: !11 "Fix a number of issues with the infrastructure, no major rework"
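The headline fix replaces annotations that the typing system cannot express, such as a bare tuple literal used as a return type, with proper typing constructs that a checker like mypy can verify. A minimal sketch of the pattern, using a hypothetical split_batch helper rather than the project's code:

from typing import Tuple

import torch
from torch import Tensor


def split_batch(batch: Tensor) -> Tuple[Tensor, Tensor]:
    """Return (inputs, targets) shifted by one position.

    Annotating the return as '(Tensor, Tensor)' is not a valid type
    expression and is rejected by mypy; Tuple[Tensor, Tensor] checks.
    """
    return batch[:, :-1], batch[:, 1:]


x, y = split_batch(torch.arange(12).reshape(3, 4))
print(x.shape, y.shape)  # torch.Size([3, 3]) torch.Size([3, 3])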
import time
import random
+from typing import Tuple, Iterator, Iterable
import torch
from torch import Tensor
@@ -8,7 +9,7 @@ from torch.utils.data import Dataset
from tokenizer import Tokenizer
-class _OptimusDLIter():
+class _OptimusDLIter(Iterator):
    def __init__(self, dl):
        """
        _OptimusDL iterator.
@@ -17,7 +18,7 @@ class _OptimusDLIter():
        self.dl = dl
        self.curr = 0
-    def __next__(self) -> (Tensor, Tensor):
+    def __next__(self) -> Tuple[Tensor, Tensor]:
        if self.curr > len(self.dl) - 1:
            raise StopIteration
@@ -28,8 +29,7 @@ class _OptimusDLIter():
        return x, y
-class _OptimusDL():
+class _OptimusDL(Iterable):
    def __init__(self, ds, tok, bs, seq_len, shuffle, device):
        """
        See 'OptimusDataLoader'.
@@ -80,7 +80,7 @@ class _OptimusDL():
        """
        self.device = device
-    def __iter__(self) -> _OptimusDLIter:
+    def __iter__(self) -> Iterator[_OptimusDLIter]:
        """
        Return an iterator over the dataloader object.
...
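Deriving the loader and its iterator from Iterable/Iterator declares the iteration protocol to the type checker as well as to readers. A rough sketch of the same pattern with placeholder classes, not the actual _OptimusDL implementation:

from typing import Iterable, Iterator, List, Tuple

import torch
from torch import Tensor


class _PairIter(Iterator):
    """Hypothetical iterator yielding (x, y) batch pairs."""

    def __init__(self, pairs: List[Tuple[Tensor, Tensor]]):
        self.pairs = pairs
        self.curr = 0

    def __next__(self) -> Tuple[Tensor, Tensor]:
        if self.curr > len(self.pairs) - 1:
            raise StopIteration
        x, y = self.pairs[self.curr]
        self.curr += 1
        return x, y


class _Pairs(Iterable):
    """Hypothetical iterable that hands out a fresh iterator per pass."""

    def __init__(self, pairs: List[Tuple[Tensor, Tensor]]):
        self.pairs = pairs

    def __iter__(self) -> _PairIter:
        return _PairIter(self.pairs)


batch = torch.zeros(2, 8)
for x, y in _Pairs([(batch, batch)]):
    print(x.shape, y.shape)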
@@ -27,13 +27,13 @@ _EXTRACTED_FILES = {
class TinyStoriesDataset(Dataset):
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
        """
        TinyStories dataset.
        Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
            split (str): Split to be returned. Can be 'train', 'test' or
                'valid'.
...
@@ -20,13 +20,13 @@ _EXTRACTED_FILES = {
class WikiText103Dataset(Dataset):
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
        """
        WikiText103 dataset.
        Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. Defaults to
+                os.path.expanduser('~/.cache/optimus/') if None.
            split (str): Split to be returned. Can be 'train', 'test' or
                'valid'.
...
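With strict optional checking (the default in recent mypy versions), a parameter annotated str cannot default to None; str | None (PEP 604, Python 3.10+) or Optional[str] states the intent and satisfies the checker. A small sketch with an assumed resolve_root helper:

import os


def resolve_root(root: str | None = None) -> str:
    """Fall back to the default cache directory when no root is given."""
    if root is None:
        root = os.path.expanduser('~/.cache/optimus/')
    return root


print(resolve_root())         # e.g. /home/user/.cache/optimus/
print(resolve_root('/data'))  # /data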
import sys
import time
import math
import argparse
import fire
import torch
from torch import nn
from torch.utils.data import Dataset
from datasets.wikitext103 import WikiText103Dataset
from tokenizer import Tokenizer
...
@@ -53,7 +53,7 @@ class Attention(nn.Module):
    def forward(self,
                x: torch.Tensor,
-               mask: torch.Tensor = None) -> torch.Tensor:
+               mask: torch.Tensor | None = None) -> torch.Tensor:
        bs, seq_len, _ = x.shape  # (batch_size, seq_len, dim)
...
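Once mask is annotated torch.Tensor | None, the checker also expects the value to be narrowed before it is used as a tensor, which matches how an optional additive mask is normally handled. A brief sketch, not the module's actual attention code:

import torch


def apply_mask(scores: torch.Tensor,
               mask: torch.Tensor | None = None) -> torch.Tensor:
    """Add an additive attention mask only when one is provided."""
    if mask is not None:  # narrowing: mask is a Tensor below this point
        scores = scores + mask
    return scores


scores = torch.zeros(1, 4, 4)
causal = torch.triu(torch.full((4, 4), float('-inf')), diagonal=1)
print(apply_mask(scores, causal)[0, 0])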
@@ -77,12 +77,12 @@ class Tokenizer():
        """
        assert type(input) is str
-        input = self.sp_model.encode_as_ids(input)
+        ids: List[int] = self.sp_model.encode_as_ids(input)
        if bos:
-            input = [self.bos_id] + input
+            ids = [self.bos_id] + ids
        if eos:
-            input = input + [self.eos_id]
-        return input
+            ids = ids + [self.eos_id]
+        return ids
    def encode_as_pieces(self, input: str) -> List[str]:
        """
...
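Reusing the input parameter (declared str) to hold the encoded ids changes the variable's type mid-function, which static checkers flag; binding the result to a separately annotated ids name keeps each variable at a single type. A toy, self-contained version of the pattern that maps characters to code points instead of calling SentencePiece:

from typing import List

BOS_ID = 1
EOS_ID = 2


def encode(text: str, bos: bool = True, eos: bool = False) -> List[int]:
    # Bind the result to a new, explicitly typed name instead of
    # re-assigning `text`, whose declared type is str.
    ids: List[int] = [ord(ch) for ch in text]
    if bos:
        ids = [BOS_ID] + ids
    if eos:
        ids = ids + [EOS_ID]
    return ids


print(encode("hi", eos=True))  # [1, 104, 105, 2]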
import os
import time
import math
+from typing import Callable
import torch
import torch.nn as nn
import torch.optim as optim
+from fastprogress.fastprogress import master_bar, progress_bar, format_time
from dataloader import OptimusDataLoader
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
class Trainer():
@@ -14,8 +15,8 @@ class Trainer():
    def __init__(self,
                 dl: OptimusDataLoader,
                 model: nn.Module,
-                 criterion: callable,
-                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 optimizer: optim.Optimizer,
                 lr: float,
                 grad_acc_steps: int,
                 grad_clip_norm: float,
...
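The built-in callable is a function, not a type, so criterion: callable is rejected by checkers; typing.Callable (optionally with a full signature) and optim.Optimizer are the checkable equivalents. A minimal sketch of a trainer-style constructor with placeholder names, not the repository's Trainer:

from typing import Callable

import torch
from torch import nn, optim


class ToyTrainer:
    """Illustrative only; mirrors the annotation pattern of the change."""

    def __init__(self,
                 model: nn.Module,
                 criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 optimizer: optim.Optimizer,
                 lr: float):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr = lr


model = nn.Linear(4, 2)
trainer = ToyTrainer(model,
                     nn.functional.cross_entropy,
                     optim.SGD(model.parameters(), lr=0.1),
                     lr=0.1)
print(type(trainer.optimizer).__name__)  # SGD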