
MTF dataset and packing #293

Merged: 183 commits from thomas/mtf_dataset_and_packing into main on Jul 2, 2022
Conversation

thomasw21
Member

No description provided.

@thomasw21 thomasw21 changed the title MTF dataset and packing [WIP] MTF dataset and packing Jun 30, 2022
@thomasw21 thomasw21 requested a review from Muennighoff June 30, 2022 13:24
@thomasw21 thomasw21 changed the title [WIP] MTF dataset and packing MTF dataset and packing Jun 30, 2022
megatron/data/mtf_dataset.py (resolved)
# `segment_ids` is [1,2,...]
self.assertEqual(segment_ids[:-1], list(range(1, len(segment_ids))))
# `0` signifies that the tokens are padding
self.assertEqual(segment_ids[-1], 0)
Collaborator

Why does the last segment id have to be padding?

Member Author

You're right, I'll fix them.

Member Author

Done in c74dbb7
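
For context, a minimal sketch of the segment-id convention the test above checks (illustrative only, not the PR's code): packed samples are numbered 1, 2, ..., and 0 marks padding positions, which only appear when the sequence is not completely filled.

import numpy as np

# Hypothetical example: three packed samples of lengths 3, 2 and 2
# in a sequence of length 10, leaving 3 padding positions.
sample_lengths = [3, 2, 2]
sequence_length = 10

segment_ids = np.zeros(sequence_length, dtype=np.int64)  # 0 marks padding
offset = 0
for segment, length in enumerate(sample_lengths, start=1):
    segment_ids[offset:offset + length] = segment
    offset += length

print(segment_ids.tolist())  # [1, 1, 1, 2, 2, 3, 3, 0, 0, 0]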

tests/test_dataloaders.py (resolved)
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
@thomasw21 thomasw21 requested a review from Muennighoff June 30, 2022 16:31
Comment on lines +84 to +88
return {
"decoder_target_tokens": decoder_target_tokens,
"decoder_segment_ids": decoder_segment_ids,
"decoder_causal_attention": decoder_causal_attention,
}
Collaborator

Still need to cast them to torch tensors / directly use torch tensors instead of numpy

Member Author

Doing it in the next one.
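
For illustration, the cast could look something like the sketch below (an assumption, not the follow-up change itself): the numpy arrays in the returned dict are converted to torch tensors once, at the end.

import numpy as np
import torch

def to_torch(batch):
    # `batch` is assumed to be the dict returned above, e.g. with keys
    # decoder_target_tokens, decoder_segment_ids and decoder_causal_attention
    # holding numpy arrays (or plain lists) of token ids / flags.
    return {key: torch.from_numpy(np.asarray(value)) for key, value in batch.items()}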

skip_warmup=(not args.mmap_warmup)
)

# TODO @thomasw21 make sure that input and target are aligned.
Collaborator

Note: We will check this when preprocessing

@thomasw21 thomasw21 merged commit c5b88fb into main Jul 2, 2022
@thomasw21 thomasw21 deleted the thomas/mtf_dataset_and_packing branch July 2, 2022 04:16
# The last batch will be dropped if it is not complete.
for idx in idx_range:
    tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
    if token_lens + tok_len > self.sequence_length:
Member Author
thomasw21 commented Jul 2, 2022

I think this is a good opportunity to add the EOS token that we didn't have during pretraining:

{INPUT} {TARGET} <EOS>

Not sure whether we should also have a separator token between the input and the target.
WDYT @TevenLeScao @Muennighoff @lintangsutawika

Collaborator

Good idea!
I think we should add an EOS token only after each target

Member Author

Yeah, agreed on adding it only after the target.

Collaborator

Agreed
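
To make the agreed-on format concrete, here is a hedged sketch of greedy packing with `{INPUT} {TARGET} <EOS>` and the EOS appended only after each target (a hypothetical helper, not the PR's implementation; the `input_tokens`/`target_tokens` keys follow the snippet above):

def greedy_pack(samples, sequence_length, eos_token_id):
    # `samples` is an iterable of dicts with 'input_tokens' and 'target_tokens'
    # token-id lists. The EOS token is appended only after each target, and the
    # trailing incomplete pack is dropped, mirroring the comment in the snippet above.
    pack = []
    for sample in samples:
        tokens = sample['input_tokens'] + sample['target_tokens'] + [eos_token_id]
        if len(tokens) > sequence_length:
            continue  # an example that can never fit is skipped (not handled here)
        if len(pack) + len(tokens) > sequence_length:
            yield pack
            pack = []
        pack.extend(tokens)
    # the final, partially filled pack is intentionally not yielded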

g = torch.Generator()
g.manual_seed(self.epoch)

random_idx = torch.randperm(bucket_size, generator=g).tolist()
Member Author

I haven't thought about it very clearly, but essentially this means the entire dataset can be randomly accessed, which isn't possible if you pack things greedily (the problem is: given an index, can you know which original samples you need to add to the batch?). So I'm refactoring this to put it in another dataset. Sorry about this @Muennighoff, you were right.
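
A small illustration of the random-access problem described above (hypothetical code, just to show the point): with greedy packing, finding the original samples behind packed index `i` requires replaying the packing from the start, because every pack boundary depends on all earlier sample lengths.

def samples_for_packed_index(sample_lengths, sequence_length, packed_index):
    # Linear scan from the beginning: there is no closed-form mapping from a
    # packed index back to the original samples, which is why a greedily packed
    # dataset cannot be randomly accessed the way a plain indexed dataset can.
    current_pack, members, filled = 0, [], 0
    for sample_index, length in enumerate(sample_lengths):
        if filled + length > sequence_length:
            if current_pack == packed_index:
                return members
            current_pack += 1
            members, filled = [], 0
        members.append(sample_index)
        filled += length
    return members if current_pack == packed_index else []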

@stas00
Contributor

stas00 commented Jul 4, 2022

FYI, tests/test_dataloaders.py::TestDataLoading::test_mlm_dataset fails.

I've added a skip so that CI doesn't run it: #302
Please fix the test at your convenience.

While at it I partially fixed it: the test could never have worked, as the args were wrong.

After the partial fix (so that it actually runs), it fails on CI and on my desktop as well:

>               train_ds, valid_ds, test_ds = mlm_dataset.build_train_valid_test_datasets(
                    data_prefix=args.data_path,
                    data_impl=args.data_impl,
                    splits_string=args.split,
                    # TODO @thomasw21 figure how that value works
                    train_valid_test_num_samples=train_val_test_num_samples,
                    sequence_length=args.seq_length,
                    noise_density=args.noise_density,
                    mean_noise_span_length=args.mean_noise_span_length,
                    seed=args.seed,
                    skip_warmup=(not args.mmap_warmup)
                )

tests/test_dataloaders.py:90: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
megatron/data/mlm_dataset.py:25: in build_train_valid_test_datasets
    return _build_train_valid_test_datasets(
megatron/data/mlm_dataset.py:263: in _build_train_valid_test_datasets
    train_dataset = build_dataset(0, 'train')
megatron/data/mlm_dataset.py:250: in build_dataset
    dataset = MLMDataset(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <megatron.data.mlm_dataset.MLMDataset object at 0x7fc94710c070>, name = 'train'
indexed_dataset = <megatron.data.indexed_dataset.MMapIndexedDataset object at 0x7fc94710c880>
documents = array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,...9, 930, 931, 932, 933, 934, 935,
       936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947],
      dtype=int32)
data_prefix = '/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master/tests/data/gpt2/meg-gpt2-openwebtext_text_document'
sequence_length = 512, num_samples = 40000, seed = 1234, noise_density = 0.15, mean_noise_span_length = 3

    def __init__(
        self,
        name,
        indexed_dataset,
        documents,
        data_prefix,
        sequence_length,
        num_samples,
        seed,
        noise_density=0.15,
        mean_noise_span_length=3
    ):
    
        # Params to store.
        self.name = name
        self.seed = seed
        self.sequence_length = sequence_length
    
        # Dataset.
        self.indexed_dataset = indexed_dataset
    
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
        # To ensure that the input length is `sequence_length`, we need to increase the maximum length
        # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
        number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
            sequence_length=self.sequence_length,
            noise_density=self.noise_density,
            mean_noise_span_length=self.mean_noise_span_length
        )
        self.inputs_length = inputs_length
        # In order to compute loss, we need an extra token at the end.
        self.number_of_raw_tokens = number_of_raw_tokens + 1
        self.targets_length = targets_length + 1
        self.num_noise_spans = num_noise_spans
    
        # Build the samples mapping.
        self._gpt_dataset = GPTDataset(
            name=self.name,
            data_prefix=data_prefix,
            documents=documents,
            indexed_dataset=self.indexed_dataset,
            num_samples=num_samples,
            # -1 because GPTDataset will return `seq_length + 1` sequences.
            seq_length=number_of_raw_tokens - 1,
            seed=seed
        )
    
        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.sep_id = tokenizer.sep
        self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
        assert self.sep_id is not None, "MLM dataset requires tokenizer to have a <sep> token"
        assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
        assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"
    
        args = get_args()
        if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None:
            # T5 style
>           assert self.inputs_length == args.encoder_seq_length
E           AssertionError

megatron/data/mlm_dataset.py:332: AssertionError
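
For background on what the failing assertion compares, a rough sketch of T5-style span-corruption length arithmetic (adapted from the common convention, not this repo's compute_input_and_target_lengths): the corrupted input keeps the non-noise tokens plus one sentinel per noise span plus an EOS, so the raw span has to be longer than the encoder sequence length for the input to fill it.

def span_corruption_lengths(tokens_length, noise_density=0.15, mean_noise_span_length=3):
    # Assumed convention following the usual T5 span-corruption arithmetic.
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_noise_spans = max(1, int(round(num_noise_tokens / mean_noise_span_length)))
    num_nonnoise_tokens = tokens_length - num_noise_tokens
    inputs_length = num_nonnoise_tokens + num_noise_spans + 1   # sentinels + EOS
    targets_length = num_noise_tokens + num_noise_spans + 1     # sentinels + EOS
    return inputs_length, targets_length

# With these assumed parameters, span_corruption_lengths(568) == (512, 114):
# roughly 568 raw tokens are needed for the corrupted input to reach an encoder length of 512.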

younesbelkada pushed a commit to younesbelkada/Megatron-DeepSpeed that referenced this pull request Sep 28, 2022
Co-authored-by: Lintang Sutawika <lintang@datasaur.ai>
Co-authored-by: lintangsutawika <lintang@sutawika.com>
Co-authored-by: Muennighoff <n.muennighoff@gmail.com>