Switch to PyTorch Dataloader and HF datasets
Pull Request Title
Description
Wants to merge: vladb/py_dataloader into main
Type of change
-
Bug fix -
New feature -
Enhancement -
Documentation update -
Other (specify right below)
Merge request commits
- Switch to PyTorch Dataloader
Related Issues
Screenshots or GIFs
Checklist
-
I have tested the code with the changes manually. -
My code follows the project's style guidelines. -
I have documented my code for others to understand. -
I have updated documentation as needed (including README.md
, code comments and doc strings).
Reviewer Guidelines
Additional Notes
@mentions
Edited by Vlad-Andrei BĂDOIU (78692)
Merge request reports
Activity
Filter activity
@agherghescu2411 I'm still testing this, but please take a look when you have some time.
assigned to @agherghescu2411
requested review from @agherghescu2411
@agherghescu2411 Please review this when you have some time!
2 2 import torch 3 3 from torch import nn 4 4 5 from optimus.datasets import WikiText103Dataset 6 5 from optimus.tokenizers import SentencePieceTokenizer 7 from optimus.dataloader import OptimusDataLoader 6 from optimus.dataloader import * 8 7 from optimus.models import OptimusTransformer 9 8 from optimus.trainer import Trainer 10 9 from datasets import load_dataset 11 10 65 64 tok = SentencePieceTokenizer(model_path=tokenizer_path) 66 65 67 66 # load dataset splits 68 train_ds = WikiText103Dataset(split='train') 69 test_ds = WikiText103Dataset(split='test') 67 train_ds = load_dataset('wikitext', 'wikitext-2-v1', split='train', streaming=False) 68 test_ds = load_dataset('wikitext', 'wikitext-2-v1', split='test', streaming=False) 69 70 # toknize splits 65 64 tok = SentencePieceTokenizer(model_path=tokenizer_path) 66 65 67 66 # load dataset splits 68 train_ds = WikiText103Dataset(split='train') 69 test_ds = WikiText103Dataset(split='test') 67 train_ds = load_dataset('wikitext', 'wikitext-2-v1', split='train', streaming=False) 68 test_ds = load_dataset('wikitext', 'wikitext-2-v1', split='test', streaming=False) - Comment on lines +67 to +68
Remove
streaming
; the default isFalse
, and mentioning it here doesn't help us too much anyway, because we can't set it toTrue
either way (setting it toTrue
doesn't work because we need to get its__len__()
in the DataLoader; however, since it is anIterableDataset
, it doesn't have a valid__len__()
).
Please register or sign in to reply