diff --git a/optimus/trainer.py b/optimus/trainer.py
index ef011a24ec710279215b7c334d93bcce5f42cfbd..39db79ab5c64978d4f4f161de094bad59ebff78e 100644
--- a/optimus/trainer.py
+++ b/optimus/trainer.py
@@ -1,221 +1,250 @@
+import json
 from pathlib import Path
-from typing import Optional, Callable
+from typing import Optional
 
+from tqdm import tqdm
 import torch
 import torch.nn as nn
-from torch import optim
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from fastprogress.fastprogress import master_bar, progress_bar, format_time
 
 
-class Trainer():
+class TrainingArguments():
     """
-    Trainer implementation for Optimus models.
+    Container for the switches and knobs that control training.
 
     Args:
-        dl (OptimusDataLoader): Dataloader used to train the model.
-        model (nn.Module): Model to train.
-        criterion (callable): A suitable loss function. The trainer assumes
-            nn.CrossEntropyLoss, though other loss funcs (like
-            label-smoothed cross entropy loss) can be used.
-        optimizer (torch.optim.Optimizer): Optimizer to use for training.
-        lr (float): Max learning rate value to use for one-cycle scheduling.
-        grad_acc_steps (int): Number of gradient accumulation steps before
-            running backpropagation. Setting to 1 effectively disables
-            gradient accumulation.
-        grad_clip_norm (float): Gradient clipping norm value.
-        model_save_path (str): The best model (based on validation loss) is
-            saved to the specified path.
-        use_fp16 (bool): Whether to train the model in 16-bit floating point
-            precision. If such hardware is not supported, a warning is
-            issued and normal 32-bit precision is used instead.
-        progress_bar (bool): Whether to show a progress bar in console while
-            training. This is automatically disabled if output is a file,
-            however some stats are printed after finishing epochs. If False,
-            no stats are printed at all during training, whether the output
-            is a console or a file.
+        device (torch.device): GPU device on which to train.
+        log_steps (int): Log training progress every `log_steps` optimizer
+            updates (not micro-batches), so `gradient_accumulation_steps`
+            influences how often logs are printed.
+        show_progress (bool, defaults to True): Whether to show progress during
+            training.
+        seed (int): Seed used for reproducibility purposes.
+        optimizer (torch.optim.Optimizer): Optimizer used for training.
+        lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Scheduler used for
+            learning rate adjustment during training.
+        num_train_epochs (int): Number of training epochs.
+        per_device_batch_size (int): Batch size per device (GPU).
+        gradient_accumulation_steps (int): Steps to accumulate the gradient for.
+        max_grad_norm (float): Max gradient clipping norm.
+        use_fp16 (bool): Whether to train in 16-bit floating point precision.
+            Uses bfloat16 if available, otherwise regular fp16.
+        checkpoints_dir (Path): Path to where training checkpoints should be
+            saved.
+        save_steps (int): Save a checkpoint every `save_steps` steps. If 0, no
+            checkpoints are saved.
+        save_limit (int): Keep at most `save_limit` checkpoints, deleting the
+            oldest ones once this limit is reached.
 
     """
-    def __init__(self,
-                 model: nn.Module,
-                 train_dataloader: DataLoader,
-                 eval_dataloader: DataLoader,
-                 tokenizer: PreTrainedTokenizerBase,
-                 criterion: Callable,
-                 optimizer: optim.Optimizer,
-                 lr: float,
-                 grad_acc_steps: int,
-                 grad_clip_norm: float,
-                 model_save_path: str,
-                 use_fp16: bool,
-                 progress_bar: bool = True):
-        self.train_dataloader = train_dataloader
-        self.eval_dataloader = eval_dataloader
-        self.model = model
-        self.criterion = criterion
+    def __init__(
+        self,
+        device: torch.device,
+        log_steps: int,
+        seed: int,
+        optimizer: torch.optim.Optimizer,
+        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+        num_train_epochs: int,
+        per_device_batch_size: int,
+        gradient_accumulation_steps: int,
+        max_grad_norm: float,
+        use_fp16: bool,
+        checkpoints_dir: Path,
+        save_steps: int,
+        save_limit: int,
+        show_progress: bool = True,
+    ):
+        self.device = device
+        self.log_steps = log_steps
+        self.show_progress = show_progress
+        self.seed = seed
         self.optimizer = optimizer
-        self.lr = lr
+        self.lr_scheduler = lr_scheduler
+        self.num_train_epochs = num_train_epochs
+        self.per_device_batch_size = per_device_batch_size
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.max_grad_norm = max_grad_norm
+        self.use_fp16 = use_fp16
+        self.checkpoints_dir = checkpoints_dir
+        self.save_steps = save_steps
+        self.save_limit = save_limit
 
-        assert type(grad_acc_steps) is int and grad_acc_steps > 0
-        self.grad_acc_steps = grad_acc_steps
+    @classmethod
+    def from_json_file(cls, file_path):
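+        # note: JSON values are forwarded to the constructor as-is, so fields
+        # such as `device` and `checkpoints_dir` arrive as plain strings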
+        with open(file_path, 'r') as file:
+            return cls(**json.load(file))
 
-        self.grad_clip_norm = grad_clip_norm
-        self.model_save_path = model_save_path
 
-        self.use_fp16 = use_fp16
-        self.fp16_dtype = torch.float16
+class Trainer():
+    """
+    Generic PyTorch trainer implementation.
 
-        self.progress_bar = progress_bar
+    Args:
+        train_args (TrainingArguments): Training arguments.
+        model (nn.Module): Model to train.
+        train_dataloader (torch.utils.data.DataLoader): Dataloader used for model
+            training.
+        eval_dataloader (torch.utils.data.DataLoader): Dataloader used for model
+            evaluation.
+        tokenizer (PreTrainedTokenizerBase, optional): HuggingFace tokenizer,
+            used for data collation. It is saved along with the model. If not
+            passed, a default collator is used.
 
-    def fit(self, n_epochs: int) -> None:
-        """
-        Fit the model, using the trainer, on the data inside the dataloader
-        object.
+    """
 
-        Args:
-            n_epochs (int): Number of epochs to train for.
+    def __init__(
+        self,
+        train_args: TrainingArguments,
+        model: nn.Module,
+        train_dataloader: DataLoader,
+        eval_dataloader: DataLoader,
+        tokenizer: Optional[PreTrainedTokenizerBase],
+    ):
+        self.args = train_args
+        self.model = model
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
+        self.tokenizer = tokenizer
 
+    def train(self) -> None:
         """
-        # this is fastai's implementation of one cycle
-        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
-            optimizer=self.optimizer,
-            max_lr=self.lr,
-            epochs=n_epochs,
-            steps_per_epoch=len(self.dl.train) // self.grad_acc_steps)
+        Training loop of the trainer.
 
-        # scaler used for mixed precision fp16 training on GPU
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16)
+        """
+        num_examples = len(self.train_dataloader) * self.args.per_device_batch_size
+        num_update_steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
+        max_steps = self.args.num_train_epochs * num_update_steps_per_epoch
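+        # the trailing factor is the world size (a single process for now)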
+        global_batch_size = self.args.per_device_batch_size * self.args.gradient_accumulation_steps * 1
+
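+        # prefer bfloat16 where the hardware supports it, otherwise fall back to fp16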
+        fp16_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
         loss_fn = torch.nn.CrossEntropyLoss()
 
-        # progress bar for epochs
-        self.mb = master_bar(list(range(n_epochs)))
+        self.progress = tqdm(range(max_steps), disable=(not self.args.show_progress))
+
+        # scaler used for mixed precision fp16 training on GPU
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.use_fp16)
 
-        # start training for n_epochs
-        for self.epoch in range(n_epochs):
+        print("***** Running training *****")
+        print(f"  Num examples = {num_examples:,}")
+        print(f"  Num epochs = {self.args.num_train_epochs:,}")
+        print(f"  Instantaneous batch size per device = {self.args.per_device_batch_size:,}")
+        print(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        print(f"  Global batch size (w. distributed & accumulation) = {global_batch_size:,}")
+        print(f"  Total optimization steps = {max_steps:,}")
 
-            if self.progress_bar is True:
-                self.mb.update(self.epoch)
+        self.model.train()
 
-            start_time = time.time()
-            self._do_epoch_train()
-            self._do_epoch_validate()
-            self.epoch_time = time.time() - start_time
+        # start training for num_train_epochs
+        for epoch in range(self.args.num_train_epochs):
 
-            # write end of epoch stats
-            self._write_epoch_stats()
+            # needed for distributed sampler RNG state
+            if hasattr(self.train_dataloader.sampler, "set_epoch"):
+                self.train_dataloader.sampler.set_epoch(epoch)
 
-            # if better model on validation loss, save it
-            if self.val_loss < best_val_loss:
-                best_val_loss = self.val_loss
-                torch.save(self.model.state_dict(), self.model_save_path)
+            for step, inputs in enumerate(self.train_dataloader):
 
-    def _do_epoch_train(self):
-        self.model.train() # put model in training mode
+                inputs = inputs['input_ids']
+                inputs = inputs.to(self.args.device)
 
-        # compute average train loss, train perplexity and ms/batch every ~200
-        # batches, or every 10% of training dataset (whichever is smaller),
-        # rounded to gradient accumulation steps
-        est_interval = int(max(min(200, 0.1 * len(self.dl.train)) // self.grad_acc_steps, 1) * self.grad_acc_steps)
+                with torch.cuda.amp.autocast(dtype=fp16_dtype,
+                                             enabled=self.args.use_fp16):
+                    logits = self.model(inputs)
 
-        # progress bar for batches
-        pb = progress_bar(range(len(self.dl.train)), parent=self.mb)
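+                    # shift by one position: the token at position t predicts token t+1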
+                    labels = inputs[..., 1:].contiguous().view(-1)
+                    logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
 
-        self.ms_per_batch = 0.
-        total_loss = 0.
-        start_time = time.time()
+                    loss = loss_fn(logits, labels)
 
-        for i, (x, y) in enumerate(self.dl.train):
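+                    # keep the unnormalized loss of this micro-batch for logging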
+                    tr_loss = loss.item()
+                    loss = loss / self.args.gradient_accumulation_steps # normalize to account for gradient accumulation
 
-            if self.progress_bar is True:
-                pb.update(i)
+                self.scaler.scale(loss).backward()
 
-            # automatic mixed precision training
-            with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
-                                         enabled=self.use_fp16):
-                output = self.model(x)
-                loss = self.criterion(output.view(-1, len(self.dl.train.tok)),
-                                      y.reshape(-1))
+                # update only after gradient_accumulation_steps
+                if (step + 1) % self.args.gradient_accumulation_steps == 0:
+                    # Note: This will ignore the last few batches of the dataset,
+                    # when the gradient accumulation steps are more than 1, and the
+                    # number of batches doesn't cleanly divide by grad_acc_steps
 
-                loss = loss / self.grad_acc_steps # normalize to account for gradient accumulation
+                    # gradient clipping
+                    self.scaler.unscale_(self.args.optimizer)
+                    nn.utils.clip_grad_norm_(self.model.parameters(),
+                                             max_norm=self.args.max_grad_norm)
 
-            self.scaler.scale(loss).backward()
+                    self.scaler.step(self.args.optimizer)
+                    self.scaler.update()
+                    self.args.optimizer.zero_grad()
 
-            total_loss += loss.item()
+                    self.args.lr_scheduler.step()
+                    # lr = self.args.lr_scheduler.get_last_lr()[0]
 
-            # update only after grad_acc_steps
-            if (i + 1) % self.grad_acc_steps == 0:
-                # Note: This will ignore the last few batches of the dataset,
-                # when the gradient accumulation steps are more than 1, and the
-                # number of batches doesn't cleanly divide by grad_acc_steps
+                    if (step + 1) % (self.args.log_steps * self.args.gradient_accumulation_steps) == 0:
+                        print(f"Loss: {tr_loss:.4f}")
 
-                # gradient clipping
-                self.scaler.unscale_(self.optimizer)
-                nn.utils.clip_grad_norm_(self.model.parameters(),
-                                         max_norm=self.grad_clip_norm)
+                    self.progress.update(1)
 
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.optimizer.zero_grad()
+    # def _do_epoch_validate(self):
+    #     self.model.eval() # put model in eval mode
 
-                self.scheduler.step()
-                lr = self.scheduler.get_last_lr()[0]
+    #     total_loss = 0.
 
-                # update train loss, train ppl and estimated ms/batch
-                if (i + 1) % est_interval == 0:
-                    self.ms_per_batch = (time.time() - start_time) * 1000 / est_interval
-                    self.train_loss = (total_loss * self.grad_acc_steps) / est_interval
-                    self.train_ppl = math.exp(self.train_loss)
+    #     # progress bar for batches
+    #     pb = progress_bar(range(len(self.dl.test)), parent=self.mb)
 
-                    total_loss = 0.
-                    start_time = time.time()
+    #     with torch.no_grad():
+    #         for i, (x, y) in enumerate(self.dl.test):
 
-                self.mb.child.comment = f" | train loss: {loss.item() * self.grad_acc_steps:.4f} | " \
-                                        f"~{self.ms_per_batch:.2f} ms/batch | " \
-                                        f" lr: {lr:.7f}"
+    #             if self.progress_bar is True:
+    #                 pb.update(i)
 
-        pb.on_iter_end()
+    #             with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
+    #                                          enabled=self.use_fp16):
+    #                 output = self.model(x)
+    #                 loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
+    #                                       y.reshape(-1))
 
-    def _do_epoch_validate(self):
-        self.model.eval() # put model in eval mode
+    #             total_loss += loss.item()
 
-        total_loss = 0.
+    #             self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
 
-        # progress bar for batches
-        pb = progress_bar(range(len(self.dl.test)), parent=self.mb)
+    #         self.val_loss = total_loss / (len(self.dl.test) - 1)
+    #         self.val_ppl = math.exp(self.val_loss)
 
-        with torch.no_grad():
-            for i, (x, y) in enumerate(self.dl.test):
+    #         pb.on_iter_end()
 
-                if self.progress_bar is True:
-                    pb.update(i)
+    # def _write_epoch_stats(self):
+    #     if self.progress_bar is True:
+    #         epoch_time = format_time(self.epoch_time)
+    #         self.mb.write(
+    #             f"* End of epoch {self.epoch:3d}:\n"
+    #             f"\tTotal time: {epoch_time:9s} | "
+    #             f"Est. ms/batch: {self.ms_per_batch:.2f}\n"
+    #             f"\tTotal train batches: {len(self.dl.train):10d} | "
+    #             f"Train loss: {self.train_loss: 7.2f} | "
+    #             f"Train perplexity: {self.train_ppl: 8.2f}\n"
+    #             f"\tTotal valid batches: {len(self.dl.test):10d} | "
+    #             f"Valid loss: {self.val_loss: 7.2f} | "
+    #             f"Valid perplexity: {self.val_ppl: 8.2f}")
 
-                with torch.cuda.amp.autocast(dtype=self.fp16_dtype,
-                                             enabled=self.use_fp16):
-                    output = self.model(x)
-                    loss = self.criterion(output.view(-1, len(self.dl.test.tok)),
-                                          y.reshape(-1))
+    def save_model(self, save_dir: Path) -> None:
+        """
+        Save model and tokenizer to a directory.
 
-                total_loss += loss.item()
+        Args:
+            save_dir (Path): Path to save directory.
 
-                self.mb.child.comment = f" | valid loss: {loss.item():.4f}"
+        """
+        pass
 
-            self.val_loss = total_loss / (len(self.dl.test) - 1)
-            self.val_ppl = math.exp(self.val_loss)
+    def save_logs(self, log_dir: Path) -> None:
+        """
+        Save training logs to a directory.
 
-            pb.on_iter_end()
+        Args:
+            log_dir (Path): Path to log directory.
 
-    def _write_epoch_stats(self):
-        if self.progress_bar is True:
-            epoch_time = format_time(self.epoch_time)
-            self.mb.write(
-                f"* End of epoch {self.epoch:3d}:\n"
-                f"\tTotal time: {epoch_time:9s} | "
-                f"Est. ms/batch: {self.ms_per_batch:.2f}\n"
-                f"\tTotal train batches: {len(self.dl.train):10d} | "
-                f"Train loss: {self.train_loss: 7.2f} | "
-                f"Train perplexity: {self.train_ppl: 8.2f}\n"
-                f"\tTotal valid batches: {len(self.dl.test):10d} | "
-                f"Valid loss: {self.val_loss: 7.2f} | "
-                f"Valid perplexity: {self.val_ppl: 8.2f}")
+        """
+        pass
diff --git a/optimus/utils/setup_utils.py b/optimus/utils/setup_utils.py
index 39f8d58315b6a6c841012d4fe9e5d69d049e4c65..7802ea4c09a200b8f01f11fe23294231b2a385a8 100644
--- a/optimus/utils/setup_utils.py
+++ b/optimus/utils/setup_utils.py
@@ -1,4 +1,52 @@
 from transformers import AutoTokenizer
+from datasets import load_from_disk
+
+from optimus.models import OptimusTransformer
+from optimus.models.optimus import OptimusConfig
+
+
+def load_and_chunk_dataset(data_dir, seq_len):
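+    """
+    Load a tokenized dataset from disk and regroup it into fixed-size chunks
+    of `seq_len` tokens each.
+
+    """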
+    tokenized_datasets = load_from_disk(data_dir)
+
+    # split dataset into chunks of seq_len
+    def split_text_fn(examples):
+        # concatenate each column's token lists into one long sequence
+        concatenated_examples = {
+            col: [token for sublist in examples[col] for token in sublist]
+            for col in examples.keys()
+        }
+        total_length = len(concatenated_examples[next(iter(examples.keys()))])
+
+        # drop the last, incomplete chunk
+        total_length = (total_length // seq_len) * seq_len
+
+        # split into chunks of size seq_len
+        result = {
+            col: [tokens[i:i + seq_len] for i in range(0, total_length, seq_len)]
+            for col, tokens in concatenated_examples.items()
+        }
+        return result
+
+    # apply text splitting into batches
+    tokenized_datasets = tokenized_datasets.map(
+        split_text_fn,
+        batched=True,
+        num_proc=6,
+        remove_columns=tokenized_datasets['train'].column_names
+    )
+    tokenized_datasets = tokenized_datasets.select_columns(['input_ids'])
+
+    print("Result: ")
+    print(tokenized_datasets)
+
+    ctx_len = len(tokenized_datasets['train'][0]['input_ids'])
+    n_batches = len(tokenized_datasets['train'])
+    print(f"Dataset info:")
+    print(f"  - context length: {ctx_len}")
+    print(f"  - number of batches (train set): {n_batches}")
+    print(f"  - total number of tokens: {ctx_len * n_batches}")
+
+    return tokenized_datasets
 
 
 def create_tokenizer(tokenizer_name):
@@ -12,3 +60,16 @@ def create_tokenizer(tokenizer_name):
         tokenizer.bos_token = tokenizer.eos_token
 
     return tokenizer
+
+
+def create_model(config_file, device):
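+    """
+    Create an OptimusTransformer from a JSON config file and move it to the
+    given device.
+
+    """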
+    config = OptimusConfig.from_json_file(config_file)
+    model = OptimusTransformer(config)
+    model.to(device)
+
+    print(model)
+
+    _total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f"Number of trainable model parameters: {_total_params}")
+
+    return model
diff --git a/training.py b/training.py
index 193de680832281147f2ee0d1990cd4a25364485d..50db31048e2e3b1e4a94ccf2478a2ba1e5b6e4a0 100644
--- a/training.py
+++ b/training.py
@@ -2,41 +2,64 @@ from pathlib import Path
 
 import fire
 import torch
-from torch import nn
 from torch.utils.data import DataLoader
-from datasets import load_from_disk
 
-from optimus.models.optimus import OptimusTransformer, OptimusConfig
-from optimus.trainer import Trainer
+from optimus.trainer import Trainer, TrainingArguments
 from optimus.utils import setup_utils
 
 
-def main(batch_size: int = 8,
-         grad_acc_steps: int = 1,
-         seq_len: int = 512,
-         lr_max: float = 1e-4,
-         grad_clip_norm: float = 1.0,
-         epochs: int = 1,
-         hf_tokenizer_name: str = 'gpt2',
-         data_dir: Path = Path('/workspace'),
-         dataset_dir: Path = Path('wikitext103_tokenized_dataset'),
-         checkpoints_path: str = 'best_model.pth',
-         dim: int = 512,
-         n_layers: int = 6,
-         n_heads: int = 8,
-         dropout: float = 0.0,
-         use_fp16: bool = True):
+def main(
+    # training args
+    batch_size: int = 2, # per GPU
+    grad_acc_steps: int = 4, # per GPU
+    seq_len: int = 4096,
+    lr_max: float = 1e-4,
+    weight_decay: float = 0.001,
+    warmup_steps: int = 1000,
+    epochs: int = 1,
+    grad_clip_norm: float = 1.0,
+    use_fp16: bool = True,
+    seed: int = 42,
+
+    # model + tokenizer
+    model_config_path: Path = Path('config.json'),
+    hf_tokenizer_name: str = 'gpt2',
+
+    # data directory (as this is usually big, it should reside in a different
+    # directory than the source code)
+    data_dir: Path = Path('.'),
+
+    # dataset dir (appended to data_dir)
+    dataset_dir: Path = Path('gpt2_tokenized_wikitext103_no_seq_len'),
+
+    # training-related dirs (appended to data_dir)
+    checkpoints_dir: Path = Path('training_checkpoints'),
+    log_dir: Path = Path('training_logs'),
+    save_dir: Path = Path('trained_model'),
+):
     """
-    Run the main training loop for the model.
+    Prepare the training arguments, then run the training loop.
 
     Args:
-        batch_size (int): Batch size for training.
+        batch_size (int): Batch size for training. In distributed setup, this is
+            per GPU.
         grad_acc_steps (int): Number of batches to accumulate gradients for
-            before running backpropagation to update weights.
-        seq_len (int): Context length for training.
+            before running backpropagation to update weights. In distributed
+            setup, this is per GPU. Global batch size is `batch_size *
+            grad_acc_steps * number of GPUs`. Adjust the learning rate as needed!
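+            For example, `batch_size=2` and `grad_acc_steps=4` on a single GPU
+            gives a global batch size of 8.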
+        seq_len (int): Context length for training. This is used to split the
+            dataset into batches.
         lr_max (float): Maximum learning rate, used for one-cycle scheduling.
-        grad_clip_norm (float): Gradient clipping value for gradient's norm.
+        weight_decay (float): Weight decay used for the optimizer.
+        warmup_steps (int): Warmup steps used for the optimizer.
         epochs (int): Number of epochs to train for.
+        grad_clip_norm (float): Gradient clipping value for gradient's norm.
+        use_fp16 (bool): Whether to train using 16-bit floating-point
+            precision. Bfloat16 is used if available, otherwise fp16.
+        seed (int): Seed used for reproducibility purposes. Each process will
+            have its initial PyTorch seed set to `seed + process_rank`.
+        model_config_path (Path): Path to a config file which describes the
+            model to be trained.
         hf_tokenizer_name (str): HuggingFace tokenizer name.
         data_dir (Path): Directory located in a large/fast storage, which holds
             data to be used by the model. Should also be capable to accomodate
@@ -47,16 +70,29 @@ def main(batch_size: int = 8,
             Should be tokenized _with the same tokenizer_ as the one used for
             training (`hf_tokenizer_name` above). The dataset should not be
             already split into batches. Given path is appended to `data_dir`.
-        checkpoints_path (str): Where to save the trained model. Should be a .pt
-            or .pth file.
-        dim (int): Dimension of the model.
-        n_layers (int): Number of layers for the model.
-        n_heads (int): Number of heads inside an attention layer for the model.
-        dropout (float): Dropout to use for the model.
-        use_fp16 (bool): Whether to train using floating-point 16-bits
-            precision.
+        checkpoints_dir (Path): Path where training checkpoints should be saved.
+            Will be created if it doesn't exist. Given path is appended to
+            `data_dir`.
+        log_dir (Path): Path where training logs should be saved. Will be
+            created if it doesn't exist. Given path is appended to `data_dir`.
+        save_dir (Path): Path where to save the model upon training completion.
+            Will be created if it doesn't exist. Given path is appended to
+            `data_dir`.
 
     """
+    # create paths
+    if isinstance(model_config_path, str):
+        model_config_path = Path(model_config_path)
+
     if isinstance(data_dir, str):
         data_dir = Path(data_dir)
 
@@ -65,34 +101,56 @@ def main(batch_size: int = 8,
     else:
         dataset_dir = data_dir / dataset_dir
 
+    if isinstance(checkpoints_dir, str):
+        checkpoints_dir = data_dir / Path(checkpoints_dir)
+    else:
+        checkpoints_dir = data_dir / checkpoints_dir
+
+    if isinstance(log_dir, str):
+        log_dir = data_dir / Path(log_dir)
+    else:
+        log_dir = data_dir / log_dir
+
+    if isinstance(save_dir, str):
+        save_dir = data_dir / Path(save_dir)
+    else:
+        save_dir = data_dir / save_dir
+
     print(f"Running with:\n"
         f"\t- batch size: {batch_size}\n"
         f"\t- gradient accumulation steps: {grad_acc_steps}\n"
         f"\t- context length: {seq_len}\n"
         f"\t- max learning rate: {lr_max}\n"
-        f"\t- gradient clipping norm: {grad_clip_norm}\n"
+        f"\t- weight decay: {weight_decay}\n"
+        f"\t- warmup steps: {warmup_steps}\n"
         f"\t- epochs: {epochs}\n"
+        f"\t- gradient clipping norm: {grad_clip_norm}\n"
+        f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
+        f"\t- seed: {seed}\n"
+        f"\t- only main rank logs: {log_on_main_rank_only}\n"
+        f"\t- model config file: {model_config_path}\n"
         f"\t- huggingface tokenizer: {hf_tokenizer_name}\n"
         f"\t- training data directory: {str(data_dir)}\n"
         f"\t- dataset directory: {str(dataset_dir)}\n"
-        f"\t- checkpoints path: {checkpoints_path}\n"
-        f"\t- model dimension: {dim}\n"
-        f"\t- model layers: {n_layers}\n"
-        f"\t- model attention heads: {n_heads}\n"
-        f"\t- model dropout: {dropout}\n"
-        f"\t- 16-bit floating-point training (fp16): {use_fp16}\n"
-        f"Please see '--help' if you want to change these settings")
+        f"\t- checkpoints directory: {str(checkpoints_dir)}\n"
+        f"\t- logging directory: {str(log_dir)}\n"
+        f"\t- saved model directory: {str(save_dir)}\n"
+        f"Please seek '--help' if you want to change any of these settings")
+
+    # set device
+    device = 'cuda'
+
+    # load dataset and split into batches
+    dataset = setup_utils.load_and_chunk_dataset(dataset_dir, seq_len)
+    dataset.set_format('torch')
 
     # load tokenizer
     tokenizer = setup_utils.create_tokenizer(hf_tokenizer_name)
 
-    # load dataset
-    dataset = load_from_disk(str(dataset_dir))
-
-    print(f'Number of examples in training set: {len(dataset['train'])}')
-    print(f'Number of examples in testing set: {len(dataset['test'])}')
+    # create model and move to device
+    model = setup_utils.create_model(model_config_path, device)
 
-    # create dataloader objects and move to device
+    # create dataloaders
     train_dataloader = DataLoader(
         dataset['train'],
         batch_size=batch_size, # per GPU
@@ -109,48 +167,74 @@ def main(batch_size: int = 8,
         pin_memory=True, # fast CPU-GPU transfer
     )
 
-    # create model and move to device
-    config = OptimusConfig(vocab_size=len(tokenizer),
-                           num_hidden_layers=n_layers,
-                           num_attention_heads=n_heads,
-                           hidden_size=dim,
-                           attention_dropout=dropout,
-                           tie_word_embeddings=False)
-    model = OptimusTransformer(config)
-    model = model.to('cuda')
-
-    _total_params = sum(p.numel() for p in model.parameters())
-    print(f'Number of model parameters: {_total_params}')
-
-    # define loss metric
-    criterion = nn.CrossEntropyLoss()
-
-    # define optimizer
-    # see [1] for a discussion on what the epsilon value should be for amp; 1e-7
-    # is a good default for both amp and normal training
-    # [1]: https://github.com/pytorch/pytorch/issues/26218
-    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-7)
+    # create optimizer
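+    # (eps=1e-7 is a good default for both amp and regular training, see
+    # https://github.com/pytorch/pytorch/issues/26218)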
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        betas=(0.9, 0.999),
+        eps=1e-7,
+        weight_decay=weight_decay,
+    )
 
-    print('Starting training...')
+    # create learning rate scheduler (one-cycle policy, popularized by fastai)
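+    # note: steps_per_epoch is measured in optimizer updates, matching the
+    # trainer's gradient accumulation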
+    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=lr_max,
+        epochs=epochs,
+        steps_per_epoch=len(train_dataloader) // grad_acc_steps,
+    )
+
+    # create training arguments
+    train_args = TrainingArguments(
+        device=torch.device(device),
+
+        # logging
+        log_steps=100,
 
-    # create trainer and start fitting
+        # core training
+        seed=seed,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=epochs,
+        per_device_batch_size=batch_size, # per GPU
+        gradient_accumulation_steps=grad_acc_steps, # per GPU
+        max_grad_norm=grad_clip_norm,
+        use_fp16=use_fp16,
+
+        # training checkpointing
+        checkpoints_dir=checkpoints_dir,
+        save_steps=1000,
+        save_limit=3,
+    )
+
+    # create trainer
     trainer = Trainer(
+        train_args=train_args,
         model=model,
         train_dataloader=train_dataloader,
         eval_dataloader=eval_dataloader,
         tokenizer=tokenizer,
-        criterion=criterion,
-        optimizer=optimizer,
-        lr=lr_max,
-        grad_acc_steps=grad_acc_steps,
-        grad_clip_norm=grad_clip_norm,
-        model_save_path=checkpoints_path,
-        use_fp16=use_fp16,
-        progress_bar=True
     )
-    trainer.fit(epochs)
 
-    print(f"Finished training! Best model weights saved at '{checkpoints_path}'")
+    print('Starting training...')
+
+    trainer.train()
+
+    print(f"Finished training! Saving model weights to '{str(save_dir)}'")
+
+    # save model + tokenizer
+    trainer.save_model(save_dir)
+
+    # save log data
+    trainer.save_logs(log_dir)
 
 
 if __name__=='__main__':