diff --git a/optimus/example_inference.py b/optimus/example_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e55e36af378f93ed1e4655271135a354c9999d0
--- /dev/null
+++ b/optimus/example_inference.py
@@ -0,0 +1,120 @@
+import sys
+import time
+import math
+import argparse
+
+import fire
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+
+from datasets.wikitext103 import WikiText103Dataset
+from tokenizer import Tokenizer
+from dataloader import OptimusDataLoader
+from model import Transformer
+from trainer import Trainer
+
+
+def main(batch_size: int = 8,
+         grad_acc_steps: int = 1,
+         seq_len: int = 512,
+         lr_max: float = 1e-4,
+         grad_clip_norm: float = 1.0,
+         epochs: int = 1,
+         tokenizer_path: str = 'optimus.model',
+         checkpoints_path: str = 'best_model.pth',
+         n_layers: int = 6,
+         dim: int = 512,
+         n_heads: int = 8,
+         dropout: float = 0.0,
+         device: str = 'cuda',
+         prompt = "Once upon time"
+):
+    """
+    Run the main training loop for the model.
+
+    Args:
+        batch_size (int): Batch size for training.
+        grad_acc_steps (int): Number of batches to accumulate gradients for
+            before running backpropagation to update weights.
+        seq_len (int): Context length for training.
+        lr_max (float): Maximum learning rate, used for one-cycle scheduling.
+        grad_clip_norm (float): Gradient clipping value for gradient's norm.
+        epochs (int): Number of epochs to train for.
+        tokenizer_path (str): Path to the tokenizer model.
+        checkpoints_path (str): Where to save the trained model. Should be a .pt
+            or .pth file.
+        n_layers (int): Number of layers for the model.
+        dim (int): Dimension of the model.
+        n_heads (int): Number of heads inside an attention layer for the model.
+        dropout (float): Dropout to use for the model.
+        device (str): Device where to train the model. Viable options are 'cpu',
+            'cuda', 'cuda:2' etc.
+
+    """
+
+    print(f"Running with:\n"
+        f"\t- batch size: {batch_size}\n"
+        f"\t- gradient accumulation steps: {grad_acc_steps}\n"
+        f"\t- context length: {seq_len}\n"
+        f"\t- max learning rate: {lr_max}\n"
+        f"\t- gradient clipping norm: {grad_clip_norm}\n"
+        f"\t- epochs: {epochs}\n"
+        f"\t- tokenizer: {tokenizer_path}\n"
+        f"\t- checkpoints path: {checkpoints_path}\n"
+        f"\t- model layers: {n_layers}\n"
+        f"\t- model dimension: {dim}\n"
+        f"\t- model attention heads: {n_heads}\n"
+        f"\t- model dropout: {dropout}\n"
+        f"\t- training on device: {device}\n"
+        f"Please see '--help' if you want to change these settings")
+
+    # load tokenizer
+    tok = Tokenizer(model_path=tokenizer_path)
+
+    # create model
+    model = Transformer(len(tok),
+                        n_layers=n_layers,
+                        dim=dim,
+                        n_heads=n_heads,
+                        p_drop=dropout,
+                        weight_tying=False)
+    
+    # load checkpoint
+    checkpoint = torch.load(checkpoints_path, map_location=device)
+    state_dict = checkpoint
+    model.load_state_dict(state_dict)
+
+    model.eval()
+    model = model.to(device)
+
+    _total_params = sum(p.numel() for p in model.parameters())
+    print(f"Number of model parameters: {_total_params}")
+
+    # create trainer and start fitting
+    trainer = Trainer(dl=None,
+                      model=model,
+                      criterion=None,
+                      optimizer=None,
+                      lr=lr_max,
+                      grad_acc_steps=grad_acc_steps,
+                      grad_clip_norm=grad_clip_norm,
+                      model_save_path=checkpoints_path,
+                      progress_bar=True)
+
+    # Run the generation for 128 tokens
+    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
+    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+    ctx = torch.amp.autocast(device_type=device, dtype=ptdtype)
+
+
+    with torch.no_grad():
+        with ctx:
+            start_ids = tok.encode(prompt, bos=False, eos=False)
+            x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+    print(f"Finished training! Best model weights saved at '{checkpoints_path}'")
+
+
+if __name__=="__main__":
+    fire.Fire(main)
diff --git a/optimus/trainer.py b/optimus/trainer.py
index 7ba83856d53089c70a9c4ccb1a3b03b3f20bbdd8..626c214da88c086071c6f27f2bf77adcbfc79223 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -196,3 +196,28 @@ class Trainer():
                 f"\tTotal valid batches: {len(self.dl.test):10d} | "
                 f"Valid loss: {self.val_loss: 7.2f} | "
                 f"Valid perplexity: {self.val_ppl: 8.2f}")
+            
+    @torch.no_grad()
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+
+        for _ in range(max_new_tokens):
+            idx_cond = idx 
+            # TODO: Once we have access to context_size change the line to
+            # if idx.size(1) <= self.context_size else idx[:, -self.context_size:]
+            logits = self(idx_cond)
+
+            logits = logits[:, -1, :] / temperature
+            if top_k is not None:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')
+            
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat((idx, idx_next), dim=1)
+        
+        return idx