diff --git a/inference.py b/inference.py
index 99426b7bcd781609fc89ccc1f9ee179656222f5f..cfa5823e9856836ab7dfdfd1327b06af5af82cee 100644
--- a/inference.py
+++ b/inference.py
@@ -1,7 +1,7 @@
-import os
 import time
 import warnings
 import logging
+from pathlib import Path
 
 import fire
 import torch
@@ -37,19 +37,28 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
     return token
 
 
-def main(model_path: str = 'model.pth',
-         hf_tokenizer_name: str = 'gpt2',
-         prompt: str | None = None,
-         max_seq_len: int = 512,
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         use_fp16: bool = True,
-         device: str = 'cuda'):
+def main(
+    # model + tokenizer
+    model_config_path: Path = Path('config.json'),
+    weights_path: Path = Path('weights.pth'),
+    hf_tokenizer_name: str = 'gpt2',
+
+    # prompt
+    prompt: str | None = None,
+
+    # inference args
+    max_seq_len: int = 512,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    use_fp16: bool = True,
+    device: str = 'cuda'
+):
     """
     Run the main inference loop for the model.
 
     Args:
-        model_path (str): Path to the saved model.
+        model_config_path (Path): Path to the model's JSON config file.
+        weights_path (Path): Path to the trained model weights (state dict).
         hf_tokenizer_name (str): HuggingFace tokenizer. Note: Ensure the same
             tokenizer is used for both training and testing, otherwise the
             results might be undefined.
@@ -72,7 +81,8 @@ def main(model_path: str = 'model.pth',
     """
 
     logger.info(f"Running with:\n"
-        f"\t- model path: '{model_path}'\n"
+        f"\t- model config path: '{model_config_path}'\n"
+        f"\t- weights file: '{weights_path}'\n"
         f"\t- huggingface tokenizer: '{hf_tokenizer_name}'\n"
         f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
         f"\t- max context length: {max_seq_len}\n"
@@ -103,34 +113,19 @@ def main(model_path: str = 'model.pth',
 
     logger.info('Loading model from disk...')
 
-    # load state from file
-    assert os.path.exists(model_path)
-    state = torch.load(model_path, map_location=device)
-
-    # these will be loaded into the model automatically later anyway, but we
-    # need at least `dim` now to be able to compute positional encodings
-    # correctly, and `vocab_sz` to make sure the tokenizer works
-    vocab_sz = int(state['vocab_sz'])
-    n_layers = int(state['n_layers'])
-    n_heads = int(state['n_heads'])
-    dim = int(state['dim'])
-    p_drop = float(state['p_drop'])
-    weight_tying = bool(state['weight_tying'])
-
-    assert vocab_sz == len(tokenizer), ('The tokenizer passed for inference is '
-        'different from the tokenizer used for training! This will result in '
-        'erroneous generation!')
-
-    # create model, load weights
-    config = OptimusConfig(vocab_size=vocab_sz,
-                           num_hidden_layers=n_layers,
-                           num_attention_heads=n_heads,
-                           hidden_size=dim,
-                           attention_dropout=p_drop,
-                           tie_word_embeddings=weight_tying)
+    # create model and load weights
+    config = OptimusConfig.from_json_file(model_config_path)
     model = OptimusTransformer(config)
-    model.load_state_dict(state, strict=True)
+    model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
+
+    # alternatively, comment out the three lines above and uncomment / adapt
+    # the following three lines to load a HuggingFace model instead
+    # from transformers import AutoTokenizer, AutoModelForCausalLM
+    # tokenizer = AutoTokenizer.from_pretrained('gpt2')
+    # model = AutoModelForCausalLM.from_pretrained('gpt2')
+
     model.eval()
+    model.to(device)
 
     logger.info(f'Loaded model on device {device}!')
 
@@ -147,7 +142,7 @@ def main(model_path: str = 'model.pth',
         input_sentence = input('User: ')
 
     # tokenize input
-    inp = torch.tensor(tokenizer(input_sentence), dtype=torch.long)
+    inp = torch.tensor(tokenizer(input_sentence)['input_ids'], dtype=torch.long, device=device)
     inp.unsqueeze_(0) # (1, seq_len)
 
     seq_len = inp.shape[-1]
@@ -160,7 +155,8 @@ def main(model_path: str = 'model.pth',
         while seq_len != max_seq_len:
 
             # get model output
-            output = model(inp) # (1, seq_len, vocab_sz)
+            output = model(inp)
+            output = getattr(output, 'logits', output) # (1, seq_len, vocab_sz)
             toks_generated += 1
 
             # get the logits for the last token
@@ -185,7 +181,7 @@ def main(model_path: str = 'model.pth',
 
             seq_len = inp.shape[-1]
 
-    logger.info(f"Model output: {' '.join(tokenizer.decode(inp.tolist()))}")
+    print(f"Model output: {''.join(tokenizer.decode(inp.squeeze(0)))}")
     logger.info(f'Tokens / second: {toks_generated / (time.time() - start_time):.2f}')
 
     logger.info('Finished inference!')