diff --git a/inference.py b/inference.py
index 99426b7bcd781609fc89ccc1f9ee179656222f5f..cfa5823e9856836ab7dfdfd1327b06af5af82cee 100644
--- a/inference.py
+++ b/inference.py
@@ -1,7 +1,7 @@
-import os
 import time
 import warnings
 import logging
+from pathlib import Path
 
 import fire
 import torch
@@ -37,19 +37,28 @@ def sample_top_p(probs: Tensor, top_p: float) -> Tensor:
     return token
 
 
-def main(model_path: str = 'model.pth',
-         hf_tokenizer_name: str = 'gpt2',
-         prompt: str | None = None,
-         max_seq_len: int = 512,
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         use_fp16: bool = True,
-         device: str = 'cuda'):
+def main(
+    # model + tokenizer
+    model_config_path: Path = Path('config.json'),
+    weights_path: Path = Path('weights.pth'),
+    hf_tokenizer_name: str = 'gpt2',
+
+    # prompt
+    prompt: str | None = None,
+
+    # inference args
+    max_seq_len: int = 512,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    use_fp16: bool = True,
+    device: str = 'cuda'
+):
     """
     Run the main inference loop for the model.
 
     Args:
-        model_path (str): Path to the saved model.
+        model_config_path (Path): Path to the model config file.
+        weights_path (Path): Path to the trained model weights.
         hf_tokenizer_name (str): HuggingFace tokenizer. Note: Ensure the same
             tokenizer is used for both training and testing, otherwise the
             results might be undefined.
@@ -72,7 +81,8 @@ def main(model_path: str = 'model.pth',
     """
 
     logger.info(f"Running with:\n"
-                f"\t- model path: '{model_path}'\n"
+                f"\t- model config path: '{model_config_path}'\n"
+                f"\t- weights file: '{weights_path}'\n"
                 f"\t- huggingface tokenizer: '{hf_tokenizer_name}'\n"
                 f"\t- prompt: '{prompt[:30] + ('' if len(prompt) <= 30 else '...') if prompt is not None else '(empty)'}'\n"
                 f"\t- max context length: {max_seq_len}\n"
@@ -103,34 +113,19 @@ def main(model_path: str = 'model.pth',
 
     logger.info('Loading model from disk...')
 
-    # load state from file
-    assert os.path.exists(model_path)
-    state = torch.load(model_path, map_location=device)
-
-    # these will be loaded into the model automatically later anyway, but we
-    # need at least `dim` now to be able to compute positional encodings
-    # correctly, and `vocab_sz` to make sure the tokenizer works
-    vocab_sz = int(state['vocab_sz'])
-    n_layers = int(state['n_layers'])
-    n_heads = int(state['n_heads'])
-    dim = int(state['dim'])
-    p_drop = float(state['p_drop'])
-    weight_tying = bool(state['weight_tying'])
-
-    assert vocab_sz == len(tokenizer), ('The tokenizer passed for inference is '
-                                        'different from the tokenizer used for training! This will result in '
-                                        'erroneous generation!')
-
-    # create model, load weights
-    config = OptimusConfig(vocab_size=vocab_sz,
-                           num_hidden_layers=n_layers,
-                           num_attention_heads=n_heads,
-                           hidden_size=dim,
-                           attention_dropout=p_drop,
-                           tie_word_embeddings=weight_tying)
+    # create model and load weights
+    config = OptimusConfig.from_json_file(model_config_path)
 
     model = OptimusTransformer(config)
-    model.load_state_dict(state, strict=True)
+    model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
+
+    # alternatively, comment the 3 lines above, and uncomment / change / adapt
+    # the following 3 lines to use a HuggingFace model
+    # from transformers import AutoTokenizer, AutoModelForCausalLM
+    # tokenizer = AutoTokenizer.from_pretrained('gpt2')
+    # model = AutoModelForCausalLM.from_pretrained('gpt2')
+    model.eval()
+    model.to(device)
 
     logger.info(f'Loaded model on device {device}!')
 
@@ -147,7 +142,7 @@ def main(model_path: str = 'model.pth',
         input_sentence = input('User: ')
 
         # tokenize input
-        inp = torch.tensor(tokenizer(input_sentence), dtype=torch.long)
+        inp = torch.tensor(tokenizer(input_sentence)['input_ids'], dtype=torch.long)
         inp.unsqueeze_(0) # (1, seq_len)
 
         seq_len = inp.shape[-1]
@@ -160,7 +155,8 @@ def main(model_path: str = 'model.pth',
 
         while seq_len != max_seq_len:
             # get model output
-            output = model(inp) # (1, seq_len, vocab_sz)
+            output = model(inp)
+            output = output['logits'] if isinstance(output, dict) else output # (1, seq_len, vocab_sz)
             toks_generated += 1
 
             # get the logits for the last token
@@ -185,7 +181,7 @@ def main(model_path: str = 'model.pth',
 
             seq_len = inp.shape[-1]
 
-        logger.info(f"Model output: {' '.join(tokenizer.decode(inp.tolist()))}")
+        print(f"Model output: {tokenizer.decode(inp.squeeze(0))}")
         logger.info(f'Tokens / second: {toks_generated / (time.time() - start_time):.2f}')
         logger.info('Finished inference!')
 