MTF dataset and packing #293
Conversation
tests/test_dataloaders.py
Outdated
# `segment_ids` is [1,2,...]
self.assertEqual(segment_ids[:-1], list(range(1, len(segment_ids))))
# `0` signify that the tokens are padding
self.assertEqual(segment_ids[-1], 0)
Why does the last segment id have to be padding?
You're right, I'll fix them.
Done in c74dbb7
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
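For context, a minimal illustration of the segment-id convention the test asserts, with made-up token values: packed documents are numbered from 1, and 0 marks padding positions.

decoder_target_tokens = [11, 42, 7, 35, 9, 18, 0, 0]  # doc 1, doc 2, then padding
decoder_segment_ids   = [ 1,  1, 1,  2, 2,  2, 0, 0]  # segment 1, segment 2, padding = 0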
return {
    "decoder_target_tokens": decoder_target_tokens,
    "decoder_segment_ids": decoder_segment_ids,
    "decoder_causal_attention": decoder_causal_attention,
}
Still need to cast them to torch tensors, or directly use torch tensors instead of numpy arrays.
Doing it in the next one.
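A minimal sketch of the kind of cast being discussed, assuming the dict keys from the snippet above and int64 dtypes; the helper name `to_torch` is hypothetical, not the PR's implementation:

import numpy as np
import torch

def to_torch(sample):
    # Convert each numpy array in the returned dict to a torch tensor.
    return {
        key: torch.from_numpy(np.asarray(value, dtype=np.int64))
        for key, value in sample.items()
    }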
skip_warmup=(not args.mmap_warmup)
)

# TODO @thomasw21 make sure that input and target are aligned.
Note: We will check this when preprocessing
# Last batch if not complete will be dropped.
for idx in idx_range:
    tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
    if token_lens + tok_len > self.sequence_length:
I think this is a good opportunity to add the EOS token that we didn't have during pretraining:
{INPUT} {TARGET} <EOS>
Not sure whether we should also have one between input and target.
WDYT @TevenLeScao @Muennighoff @lintangsutawika
Good idea!
I think we should add an EOS token only after each target
Yeah agree on only after target.
Agreed
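Putting the thread together, a hedged sketch of the greedy packing loop with an EOS appended only after each target; `greedy_pack`, `eos_token_id` and the drop-the-last-pack behaviour are assumptions here, not the PR's final implementation:

def greedy_pack(samples, sequence_length, eos_token_id):
    packed, current, current_len = [], [], 0
    for sample in samples:
        # <EOS> goes only after the target, per the discussion above.
        tokens = sample['input_tokens'] + sample['target_tokens'] + [eos_token_id]
        if current_len + len(tokens) > sequence_length:
            if current:
                packed.append(current)
            current, current_len = [], 0
        current.extend(tokens)
        current_len += len(tokens)
    # The pack still being filled at the end is dropped, mirroring the quoted
    # "last batch if not complete will be dropped" comment.
    return packed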
g = torch.Generator()
g.manual_seed(self.epoch)

random_idx = torch.randperm(bucket_size, generator=g).tolist()
I haven't thought about it very thoroughly, but essentially this means that the entire dataset can be randomly accessed, which isn't possible if you pack things greedily (the problem: given an index, can you know which original samples you need to add to the batch?). So I'm refactoring this to put it in another dataset. Sorry about this @Muennighoff, you were right.
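A hedged illustration of that point: with an epoch-seeded shuffle, every index maps to exactly one underlying sample, so the dataset stays randomly accessible. The class name `ShuffledDataset` is hypothetical.

import torch
from torch.utils.data import Dataset

class ShuffledDataset(Dataset):
    def __init__(self, dataset, epoch):
        g = torch.Generator()
        g.manual_seed(epoch)
        self.dataset = dataset
        self.random_idx = torch.randperm(len(dataset), generator=g).tolist()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Each index resolves to exactly one underlying sample: random access works.
        return self.dataset[self.random_idx[idx]]

With greedy packing, by contrast, packed sample i depends on how many raw samples fit into all previous packs, so an index can't be resolved without iterating over the data, hence the refactor into a separate dataset.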
FYI, I've added a skip so that CI doesn't run it (#302). I partially fixed it while at it, since the test couldn't possibly have ever worked: the args were wrong. After the partial fix, so that it actually runs, it fails both on CI and on my desktop.
Co-authored-by: Lintang Sutawika <lintang@datasaur.ai>
Co-authored-by: lintangsutawika <lintang@sutawika.com>
Co-authored-by: Muennighoff <n.muennighoff@gmail.com>
No description provided.