diff --git a/optimus/datasets/tinystories.py b/optimus/datasets/tinystories.py index f84b0fe67aed189c85d8b2519661b70e325ebeb6..4928a20bda06587f0ce86cedce6f1e9e32f40ad8 100644 --- a/optimus/datasets/tinystories.py +++ b/optimus/datasets/tinystories.py @@ -74,25 +74,31 @@ class TinyStoriesDataset(Dataset): else: print(f"Found dataset at '{path}'. Using this for '{split}' split...") - # open the dataset and read the splits from it + self.stories = [] + + # open the dataset file and read the stories from it with open(path, 'r') as file: - # read all the lines from the file, split into stories, remove empty - # lines and strip whitespace - text = file.read() - entries = text.split('<|endoftext|>') - entries = [*filter(lambda entry: True if entry != ' \n' else False, entries)] - entries = [entry.strip() for entry in entries] + story = [] + + for line in file: + + if line == '<|endoftext|>\n': + # found the beginning of a story; save the previous one and + # begin building a new story + self.stories.append(' '.join(story)) + story = [] - train_test_split = int(len(entries) * 0.95) - if split == 'train': - self.stories = entries[:train_test_split] + else: + # append the line to the story + story.append(line) - if split == 'test': - self.stories = entries[train_test_split:] + train_test_split = int(0.95 * len(self.stories)) - if split == 'valid': - self.stories = entries + if split == 'train': + self.stories = self.stories[:train_test_split] + elif split == 'test': + self.stories = self.stories[train_test_split:] def __len__(self) -> int: """ @@ -120,5 +126,5 @@ if __name__=='__main__': print(f"Sample for '{split}' split:") data_loader = DataLoader(dataset, batch_size=1, shuffle=True) for data in data_loader: - print(data) + print(''.join(data)) break diff --git a/optimus/datasets/wikitext103.py b/optimus/datasets/wikitext103.py index daaecc91c96c5d2896ab45f0b6fd2f6044c0cefe..2743c61e7b7f4b5b293d16e776f22ba9fff05f51 100644 --- a/optimus/datasets/wikitext103.py +++ b/optimus/datasets/wikitext103.py @@ -64,32 +64,31 @@ class WikiText103Dataset(Dataset): else: print(f"Found dataset at '{path}'. Using this for '{split}' split...") - # open the dataset and read the splits from it - with open(path, 'r') as file: + self.articles = [] - # read all the lines from the file, remove empty lines and strip - # whitespace - lines = file.readlines() - lines = [*filter(lambda line: True if line != ' \n' else False, lines)] - lines = [line.strip() for line in lines] + # we'll read the file line by line, but what we really want to do is + # have a whole article pieced together; this makes it easier to shuffle + # the data for training, while still maintaining information intact + pattern = re.compile(r'^ = [^=]+ = $') - # we're reading lines here, but what we really want to do is have a - # whole article per line; this makes it easier to shuffle the data - # for training, while still maintaining information intact - pattern = re.compile(r'^= [^=]+ =$') + # open the dataset file and read the articles from it + with open(path, 'r') as file: - self.articles = [] article = [] - for line in lines: + + for line in file: + if pattern.match(line): # found the beginning of an article; save the previous one # and begin building a new article self.articles.append(' '.join(article)) article = [] + + # append the line to the article article.append(line) - # delete the first empty article - del self.articles[0] + # delete the first empty article + del self.articles[0] def __len__(self) -> int: """ @@ -117,5 +116,5 @@ if __name__=='__main__': print(f"Sample for '{split}' split:") data_loader = DataLoader(dataset, batch_size=1, shuffle=True) for data in data_loader: - print(data) + print(''.join(data)) break diff --git a/optimus/example_training.py b/optimus/example_training.py index b3619a242703fd734541fa6f227f7dc398098c84..9bfed385050375712931ced5b290d8d8c2f0d390 100644 --- a/optimus/example_training.py +++ b/optimus/example_training.py @@ -68,8 +68,8 @@ def main(batch_size: int = 8, train_ds = WikiText103Dataset(split='train') test_ds = WikiText103Dataset(split='test') - print(f"Number of Wikipedia articles for train set: {len(train_ds)}") - print(f"Number of Wikipedia articles for test set: {len(test_ds)}") + print(f"Number of examples in training set: {len(train_ds)}") + print(f"Number of examples in testing set: {len(test_ds)}") # create dataloader object and move to device dl = OptimusDataLoader(train_ds, test_ds, tok,