diff --git a/README.md b/README.md
index 24b864e691254057ab94644e280a8b384355b89d..66a8a726bcbaa30fd0f252a4747a24d578587268 100644
--- a/README.md
+++ b/README.md
@@ -22,9 +22,14 @@ The requirements to train a model of a custom size from scratch are:
    provide). Additionally, `optimus/trainer.py` can be directly modified, for
    options currently not exposed through an interface.
 
+> [!TIP]
+> You can choose which GPUs to train on using the `CUDA_VISIBLE_DEVICES`
+> environment variable. For example, you can train on the second available GPU
+> on the system with `CUDA_VISIBLE_DEVICES=1 python optimus/example_training.py`.
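+> The variable also accepts a comma-separated list (e.g.
+> `CUDA_VISIBLE_DEVICES=0,1`) if you want to make more than one GPU visible to
+> the training process.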
+
 ## Required packages
 
 There are a number of packages required to run the thing. Get your closest
 Python retailer and ask him to run the following command:
 
-`pip install torch fire sentencepiece fastprogress`
+`pip install torch fire sentencepiece fastprogress matplotlib`
diff --git a/optimus/dataloader.py b/optimus/dataloader.py
index 65869be320a434a6a2f87a6f3a9737d87a9c8d6f..b1e00c961547fdd071d294743b8545636833f823 100644
--- a/optimus/dataloader.py
+++ b/optimus/dataloader.py
@@ -1,5 +1,6 @@
 import time
 import random
+from typing import Tuple, Iterator, Iterable
 
 import torch
 from torch import Tensor
@@ -8,7 +9,7 @@ from torch.utils.data import Dataset
 from tokenizer import Tokenizer
 
 
-class _OptimusDLIter():
+class _OptimusDLIter(Iterator[Tuple[Tensor, Tensor]]):
     def __init__(self, dl):
         """
         _OptimusDL iterator.
@@ -17,7 +18,7 @@ class _OptimusDLIter():
         self.dl = dl
         self.curr = 0
 
-    def __next__(self) -> (Tensor, Tensor):
+    def __next__(self) -> Tuple[Tensor, Tensor]:
         if self.curr > len(self.dl) - 1:
             raise StopIteration
 
@@ -28,8 +29,7 @@ class _OptimusDLIter():
 
         return x, y
 
-
-class _OptimusDL():
+class _OptimusDL(Iterable):
     def __init__(self, ds, tok, bs, seq_len, shuffle, device):
         """
         See 'OptimusDataLoader'.
@@ -51,8 +51,24 @@ class _OptimusDL():
 
         print(f"Done. Took {time.time() - start:.2f}s.")
 
-        self.num_batches = torch.cat(self._data,
-                                     dim=-1).shape[0] // self.bs // self.seq_len
+        # pre-calculate the number of batches in the dataset
+
+        # Note: there's a special case to be careful about. Since the targets
+        # are simply the inputs shifted to the right by one token, the dataset
+        # can end before the final shifted-right targets are available. This
+        # happens exactly when `batch_len % seq_len == 0`. To account for it,
+        # we have to be explicit about the available number of batches (by
+        # simply subtracting 1 from the total number of available batches when
+        # that condition holds).
+        dataset_stream_len = 0
+        for sample in self._data:
+            dataset_stream_len += len(sample)
+
+        batch_len = dataset_stream_len // self.bs
+        self.num_batches = batch_len // self.seq_len
+
+        if batch_len % self.seq_len == 0:
+            self.num_batches -= 1
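+        # (e.g. a stream of 2,000 tokens with bs=10 and seq_len=50 gives
+        # batch_len == 200; since 200 % 50 == 0, num_batches = 4 - 1 = 3)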
 
     def _process_data_before_iter(self):
         data = self._data
@@ -80,7 +96,7 @@ class _OptimusDL():
         """
         self.device = device
 
-    def __iter__(self) -> _OptimusDLIter:
+    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
         """
         Return an iterator over the dataloader object.
 
diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py
index 81d2829bf6632b00ba0938e30fcad202c3d6ef49..f84b0fe67aed189c85d8b2519661b70e325ebeb6 100644
--- a/optimus/datasets/tinystories.py
+++ b/optimus/datasets/tinystories.py
@@ -27,13 +27,13 @@ _EXTRACTED_FILES = {
 
 class TinyStoriesDataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         TinyStories dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. If None,
+                defaults to os.path.expanduser('~/.cache/optimus/').
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py
index 50b170f68befb2ec1511a8b80b824a1361e220d1..daaecc91c96c5d2896ab45f0b6fd2f6044c0cefe 100644
--- a/optimus/datasets/wikitext103.py
+++ b/optimus/datasets/wikitext103.py
@@ -20,13 +20,13 @@ _EXTRACTED_FILES = {
 
 class WikiText103Dataset(Dataset):
 
-    def __init__(self, root: str = None, split: str = 'train'):
+    def __init__(self, root: str | None = None, split: str = 'train'):
         """
         WikiText103 dataset.
 
         Args:
-            root (str): Directory where the dataset is saved. Defaults to
-                os.path.expanduser('~/.cache/optimus/').
+            root (str | None): Directory where the dataset is saved. If None,
+                defaults to os.path.expanduser('~/.cache/optimus/').
             split (str): Split to be returned. Can be 'train', 'test' or
                 'valid'.
 
diff --git a/optimus/example_training.py b/optimus/example_training.py
index a1732bed775c40bd88c3dc1e949dc8e6ad9e7945..b3619a242703fd734541fa6f227f7dc398098c84 100644
--- a/optimus/example_training.py
+++ b/optimus/example_training.py
@@ -1,12 +1,6 @@
-import sys
-import time
-import math
-import argparse
-
 import fire
 import torch
 from torch import nn
-from torch.utils.data import Dataset
 
 from datasets.wikitext103 import WikiText103Dataset
 from tokenizer import Tokenizer
@@ -23,8 +17,8 @@ def main(batch_size: int = 8,
          epochs: int = 1,
          tokenizer_path: str = 'optimus.model',
          checkpoints_path: str = 'best_model.pth',
-         n_layers: int = 6,
          dim: int = 512,
+         n_layers: int = 6,
          n_heads: int = 8,
          dropout: float = 0.0,
          device: str = 'cuda'):
@@ -42,8 +36,8 @@ def main(batch_size: int = 8,
         tokenizer_path (str): Path to the tokenizer model.
         checkpoints_path (str): Where to save the trained model. Should be a .pt
             or .pth file.
-        n_layers (int): Number of layers for the model.
         dim (int): Dimension of the model.
+        n_layers (int): Number of layers for the model.
         n_heads (int): Number of heads inside an attention layer for the model.
         dropout (float): Dropout to use for the model.
         device (str): Device where to train the model. Viable options are 'cpu',
@@ -60,8 +54,8 @@ def main(batch_size: int = 8,
         f"\t- epochs: {epochs}\n"
         f"\t- tokenizer: {tokenizer_path}\n"
         f"\t- checkpoints path: {checkpoints_path}\n"
-        f"\t- model layers: {n_layers}\n"
         f"\t- model dimension: {dim}\n"
+        f"\t- model layers: {n_layers}\n"
         f"\t- model attention heads: {n_heads}\n"
         f"\t- model dropout: {dropout}\n"
         f"\t- training on device: {device}\n"
diff --git a/optimus/model.py b/optimus/model.py
index 3ee8b4ae9e44947942ff58bb2979f3cb3788cb34..5a432655c81579bcc42b0b001db2750bd8d0d735 100644
--- a/optimus/model.py
+++ b/optimus/model.py
@@ -18,11 +18,14 @@ class Norm(nn.Module):
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
-    def forward(self, x):
-        x = x.float() # compute in float32, not in fp16, since normalization needs to be accurate
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
         std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
-        output = x / (std + self.eps) * self.weight
-        return output.type_as(x) # return value in float32
+        return x / (std + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # compute in float32, not in fp16, since normalization needs to be accurate
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
 
 
 class Attention(nn.Module):
@@ -50,7 +53,7 @@ class Attention(nn.Module):
 
     def forward(self,
                 x: torch.Tensor,
-                mask: torch.Tensor = None) -> torch.Tensor:
+                mask: torch.Tensor | None = None) -> torch.Tensor:
 
         bs, seq_len, _ = x.shape # (batch_size, seq_len, dim)
 
@@ -92,12 +95,13 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, ffn_dim: int):
         """
-        Feed forward layer. This is a modified version of the original feed
-        forward layer from "Attention Is All You Need". The changes are
-        described in "GLU variants improve transformer"
-        (https://arxiv.org/abs/2002.05202).
+        Feed forward layer.
 
-        A summary of the above:
+        This is a modified version of the original feed forward layer from
+        "Attention Is All You Need". The changes are described in "GLU variants
+        improve transformer" (https://arxiv.org/abs/2002.05202).
+
+        A summary of the changes proposed:
         - 3 linear layers instead of 2
         - reduce dimension of the 3 matrices to 2/3 * dim to keep computation
           constant
@@ -154,20 +158,20 @@ class TransformerBlock(nn.Module):
 class Transformer(nn.Module):
     def __init__(self,
                  vocab_sz: int,
-                 n_layers: int = 6,
                  dim: int = 512,
+                 n_layers: int = 6,
                  n_heads: int = 8,
                  p_drop: float = 0.0,
                  weight_tying: bool = False):
         """
         A transformer implementation. Most of it is derived directly from
         "Attention is All You Need" (see https://arxiv.org/abs/1706.03762),
-        with additional changes from more recent times.
+        with additional changes from recent research.
 
         Args:
             vocab_sz (int): Vocabulary size.
-            n_layers (int): The number of layers in the model.
             dim (int): The dimension of embeddings in the model.
+            n_layers (int): The number of layers in the model.
             n_heads (int): The number of attention heads.
             p_drop (float): Dropout probability. Dropout is applied to input
                 embeddings, the outputs of attention layers, and the outputs of
@@ -180,6 +184,15 @@ class Transformer(nn.Module):
 
         """
         super().__init__()
+
+        # register the hyperparameters as buffers so they are included in the
+        # model's state_dict() and saved along with the weights by torch.save()
+        self.register_buffer('vocab_sz', torch.tensor(vocab_sz))
+        self.register_buffer('n_layers', torch.tensor(n_layers))
+        self.register_buffer('n_heads', torch.tensor(n_heads))
+        self.register_buffer('dim', torch.tensor(dim))
+        self.register_buffer('p_drop', torch.tensor(p_drop))
+        self.register_buffer('weight_tying', torch.tensor(weight_tying))
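+        # (after loading, these can be read back as plain Python values, e.g.
+        # int(model.n_layers) or bool(model.weight_tying), to recover the
+        # configuration the model was created with)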
+
         self.embeddings = nn.Embedding(vocab_sz, dim)
         self.positional_encodings = self._compute_freqs(dim)
         self.input_dropout = nn.Dropout(p=p_drop)
@@ -244,4 +257,4 @@ class Transformer(nn.Module):
         x = self.output_norm(x)
         output = self.output(x) # (batch_size, seq_len, vocab_sz)
 
-        return output#.float() # return as float32
+        return output
diff --git a/optimus/tokenizer.py b/optimus/tokenizer.py
index 170d4a5caf946894809fbf4a27a9e2f8b5dfe8ba..0a9ac95a337941ad3f732fcaa89d881f50eb8ca0 100644
--- a/optimus/tokenizer.py
+++ b/optimus/tokenizer.py
@@ -77,12 +77,12 @@ class Tokenizer():
 
         """
         assert type(input) is str
-        input = self.sp_model.encode_as_ids(input)
+        ids: List[int] = self.sp_model.encode_as_ids(input)
         if bos:
-            input = [self.bos_id] + input
+            ids = [self.bos_id] + ids
         if eos:
-            input = input + [self.eos_id]
-        return input
+            ids = ids + [self.eos_id]
+        return ids
 
     def encode_as_pieces(self, input: str) -> List[str]:
         """
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 631160c81e76ac84048779bcc30d202a24ba731e..7ba83856d53089c70a9c4ccb1a3b03b3f20bbdd8 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,12 +1,13 @@
-import os
 import time
 import math
+from typing import Callable
 
 import torch
 import torch.nn as nn
+import torch.optim as optim
+from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 from dataloader import OptimusDataLoader
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
 class Trainer():
@@ -14,8 +15,8 @@ class Trainer():
     def __init__(self,
                  dl: OptimusDataLoader,
                  model: nn.Module,
-                 criterion: callable,
-                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 optimizer: optim.Optimizer,
                  lr: float,
                  grad_acc_steps: int,
                  grad_clip_norm: float,
@@ -101,11 +102,12 @@ class Trainer():
     def _do_epoch_train(self):
         self.model.train() # put model in training mode
 
-        # compute average train loss, train ppl and ms/batch every ~200 batches
-        # (depending on gradient accumulation steps), or every 10% of training
-        # dataset (whichever is smaller)
+        # compute average train loss, train ppl and ms/batch every ~200 batches,
+        # or every 10% of the training dataset (whichever is smaller), rounded
+        # down to a multiple of the gradient accumulation steps
         self.ms_per_batch = 0.
-        est_interval = int(max(min(200 // self.grad_acc_steps, 0.1 * len(self.dl.train)), 1)) * self.grad_acc_steps
+        total_loss = 0.
+        est_interval = max(int(min(200, 0.1 * len(self.dl.train))) // self.grad_acc_steps, 1) * self.grad_acc_steps
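+        # (e.g. with 10,000 training batches and grad_acc_steps=8 this gives
+        # min(200, 1000) // 8 * 8 == 200 batches between metric updates)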
         start_time = time.time()
 
         # progress bar for batches
@@ -143,7 +145,7 @@ class Trainer():
                 # update train loss, train ppl and estimated ms/batch
                 if (i + 1) % est_interval == 0:
                     self.ms_per_batch = (time.time() - start_time) * 1000 / est_interval
-                    self.train_loss = total_loss / est_interval
+                    self.train_loss = (total_loss * self.grad_acc_steps) / est_interval
                     self.train_ppl = math.exp(self.train_loss)
 
                     total_loss = 0.
@@ -153,10 +155,6 @@ class Trainer():
                                         f"~{self.ms_per_batch:.2f} ms/batch | " \
                                         f" lr: {lr:.7f}"
 
-        # account for last batches when computing average train loss
-        self.train_loss = total_loss / (len(self.dl.train) % est_interval - 1)
-        self.train_ppl = math.exp(self.train_loss)
-
         pb.on_iter_end()
 
     def _do_epoch_validate(self):