Commit c526355a authored by Vlad-Andrei BĂDOIU (78692)

Merge branch 'fix/datasets' into 'main'

Fix datasets memory issues

See merge request !9

parents ba87b2ba ee95daff
@@ -74,25 +74,31 @@ class TinyStoriesDataset(Dataset):
         else:
             print(f"Found dataset at '{path}'. Using this for '{split}' split...")
 
-        # open the dataset and read the splits from it
+        self.stories = []
+        # open the dataset file and read the stories from it
         with open(path, 'r') as file:
-            # read all the lines from the file, split into stories, remove empty
-            # lines and strip whitespace
-            text = file.read()
-            entries = text.split('<|endoftext|>')
-            entries = [*filter(lambda entry: True if entry != ' \n' else False, entries)]
-            entries = [entry.strip() for entry in entries]
-
-        train_test_split = int(len(entries) * 0.95)
-        if split == 'train':
-            self.stories = entries[:train_test_split]
-        if split == 'test':
-            self.stories = entries[train_test_split:]
-        if split == 'valid':
-            self.stories = entries
+            story = []
+            for line in file:
+                if line == '<|endoftext|>\n':
+                    # found the beginning of a story; save the previous one and
+                    # begin building a new story
+                    self.stories.append(' '.join(story))
+                    story = []
+                else:
+                    # append the line to the story
+                    story.append(line)
+
+        train_test_split = int(0.95 * len(self.stories))
+        if split == 'train':
+            self.stories = self.stories[:train_test_split]
+        elif split == 'test':
+            self.stories = self.stories[train_test_split:]
 
     def __len__(self) -> int:
         """
@@ -120,5 +126,5 @@ if __name__=='__main__':
         print(f"Sample for '{split}' split:")
         data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
         for data in data_loader:
-            print(data)
+            print(''.join(data))
             break
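The `print` change is a display fix: with `batch_size=1` and raw strings as samples, PyTorch's default collate function returns each batch as a list of strings, so `print(data)` shows `['...']` with list syntax, while `''.join(data)` prints the text itself. A toy dataset (illustrative only) shows the behavior:

from torch.utils.data import DataLoader, Dataset

class ToyTextDataset(Dataset):
    # illustrative stand-in for the datasets above: samples are raw strings
    def __init__(self):
        self.items = ['Once upon a time, there was a tiny robot.']

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

loader = DataLoader(ToyTextDataset(), batch_size=1, shuffle=True)
for data in loader:
    print(data)           # ['Once upon a time, there was a tiny robot.']
    print(''.join(data))  # Once upon a time, there was a tiny robot.
    break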
@@ -64,32 +64,31 @@ class WikiText103Dataset(Dataset):
         else:
             print(f"Found dataset at '{path}'. Using this for '{split}' split...")
 
-        # open the dataset and read the splits from it
-        with open(path, 'r') as file:
-            # read all the lines from the file, remove empty lines and strip
-            # whitespace
-            lines = file.readlines()
-            lines = [*filter(lambda line: True if line != ' \n' else False, lines)]
-            lines = [line.strip() for line in lines]
-
-        # we're reading lines here, but what we really want to do is have a
-        # whole article per line; this makes it easier to shuffle the data
-        # for training, while still maintaining information intact
-        pattern = re.compile(r'^= [^=]+ =$')
-        self.articles = []
-        article = []
-        for line in lines:
-            if pattern.match(line):
-                # found the beginning of an article; save the previous one
-                # and begin building a new article
-                self.articles.append(' '.join(article))
-                article = []
-            # append the line to the article
-            article.append(line)
-        # delete the first empty article
-        del self.articles[0]
+        # we'll read the file line by line, but what we really want to do is
+        # have a whole article pieced together; this makes it easier to shuffle
+        # the data for training, while still maintaining information intact
+        pattern = re.compile(r'^ = [^=]+ = $')
+        # open the dataset file and read the articles from it
+        with open(path, 'r') as file:
+            self.articles = []
+            article = []
+            for line in file:
+                if pattern.match(line):
+                    # found the beginning of an article; save the previous one
+                    # and begin building a new article
+                    self.articles.append(' '.join(article))
+                    article = []
+                # append the line to the article
+                article.append(line)
+            # delete the first empty article
+            del self.articles[0]
 
     def __len__(self) -> int:
         """
@@ -117,5 +116,5 @@ if __name__=='__main__':
         print(f"Sample for '{split}' split:")
         data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
         for data in data_loader:
-            print(data)
+            print(''.join(data))
             break
@@ -68,8 +68,8 @@ def main(batch_size: int = 8,
     train_ds = WikiText103Dataset(split='train')
     test_ds = WikiText103Dataset(split='test')
-    print(f"Number of Wikipedia articles for train set: {len(train_ds)}")
-    print(f"Number of Wikipedia articles for test set: {len(test_ds)}")
+    print(f"Number of examples in training set: {len(train_ds)}")
+    print(f"Number of examples in testing set: {len(test_ds)}")
 
     # create dataloader object and move to device
     dl = OptimusDataLoader(train_ds, test_ds, tok,