MTF dataset and packing #293

Merged 183 commits on Jul 2, 2022. Changes shown are from 181 commits.

Commits (183)
1fd8099
added train script but with prefix manually declared
May 7, 2022
fe3f8c2
made new dataset
May 9, 2022
31f2087
minor adjustments
May 9, 2022
6ccacba
added capabilities for padding and prefix lm index
lintangsutawika May 9, 2022
155a8ef
added finetune script
lintangsutawika May 9, 2022
ec6c07e
removed script
lintangsutawika May 9, 2022
435d65f
added adjustments and new dataset
May 9, 2022
b1d7bbd
try mlm dataset
May 9, 2022
714e5b7
minor changes
May 9, 2022
dc15436
minor addition of import packages
May 9, 2022
e79ac16
minor error fix
May 9, 2022
9a90a2e
minor error fix
May 9, 2022
7e79b48
samples follow how gpt dataset is loaded
May 9, 2022
3453dbd
added masked_lm_prob
May 9, 2022
d382d19
added mask id
May 9, 2022
31fbf55
added mask id
May 9, 2022
f571132
added mask id
May 9, 2022
5548a47
added mask id
May 9, 2022
21237f3
added fix
May 9, 2022
98c4635
added bos and eos token id
May 9, 2022
e1a75aa
no need for sentinel token
May 9, 2022
2fdd795
add aux functions
May 9, 2022
154f39c
add aux functions
May 9, 2022
3765a81
add aux functions
May 9, 2022
2cd4174
add pad_id
May 9, 2022
b592ea3
changed lm predictions to t5
May 18, 2022
852faca
changed lm predictions to t5
May 18, 2022
4333554
changed lm predictions to t5
May 18, 2022
be73455
changed lm predictions to t5
May 18, 2022
163c966
changed lm predictions to t5
May 18, 2022
56de89f
tokenizer add mask, cls, sep tokens
May 18, 2022
ca86fa8
commit latest changes
May 21, 2022
5011d99
commit latest changes
May 21, 2022
1b15263
added sentinel tokens
May 21, 2022
1b16541
added sentinel tokens
May 21, 2022
0603aac
added sentinel tokens
May 21, 2022
bd061d3
added additional_special_tokens
May 21, 2022
aff88b9
added additional_special_tokens
May 21, 2022
aab4729
check t5_input and output
May 21, 2022
9448ef4
check decoder in and decoder out
May 21, 2022
16ba4aa
made into input and output tokens
May 22, 2022
99ca9e8
made into input and output tokens
May 22, 2022
cdfecad
made into input and output tokens
May 22, 2022
e058688
made into input and output tokens
May 22, 2022
38ded72
made into input and output tokens
May 22, 2022
0f68be3
made into input and output tokens
May 22, 2022
eb84844
made into input and output tokens
May 22, 2022
a3af6bf
made into input and output tokens
May 23, 2022
6ad61b6
made into input and output tokens
May 23, 2022
9131fdd
added eos
May 23, 2022
cb76cd3
added eos
May 23, 2022
531ee68
test text_token
May 24, 2022
a7d1158
test text_token
May 24, 2022
0008cfb
test text_token
May 24, 2022
f1461a8
test text_token
May 24, 2022
ada0f10
test text_token
May 24, 2022
298c9b7
assigned array
May 24, 2022
d2bdff6
assigned array
May 24, 2022
4ec8db3
assigned array
May 24, 2022
10a2b6d
hardcoded sequence length
May 24, 2022
a373a70
check again
May 28, 2022
bdef71b
show sentinel tokens
lintangsutawika May 28, 2022
262fd6c
show sentinel tokens
lintangsutawika May 28, 2022
68a6a93
show sentinel tokens
lintangsutawika May 28, 2022
1c00d4b
show sentinel tokens
lintangsutawika May 28, 2022
8b85f11
add more special tokens
lintangsutawika May 28, 2022
85d204a
changed how mlm data is loaded
lintangsutawika May 28, 2022
4c84274
changed how mlm data is loaded
lintangsutawika May 28, 2022
084245e
changed how mlm data is loaded
lintangsutawika May 28, 2022
32af10e
changed how mlm data is loaded
lintangsutawika May 28, 2022
b6e0e63
changed how mlm data is loaded
lintangsutawika May 28, 2022
2af2e4b
added new script
lintangsutawika May 28, 2022
cc5968e
added new script
lintangsutawika May 28, 2022
cf0b2a0
added new script
lintangsutawika May 28, 2022
fc150a0
try t5 dataset
lintangsutawika May 28, 2022
039f90f
try t5 dataset
lintangsutawika May 28, 2022
7364781
try t5 dataset
lintangsutawika May 28, 2022
5b1100a
try t5 dataset
lintangsutawika May 28, 2022
45102a9
try t5 dataset
lintangsutawika May 28, 2022
7b2ebbf
try t5 dataset
lintangsutawika May 28, 2022
fe8b3dc
try t5 dataset
lintangsutawika May 28, 2022
f456725
try t5 dataset
lintangsutawika May 28, 2022
ae73d8c
try t5 dataset
lintangsutawika May 28, 2022
fae6a0b
try t5 dataset
lintangsutawika May 28, 2022
8185842
try t5 dataset
lintangsutawika May 28, 2022
9deef49
try t5 dataset
lintangsutawika May 28, 2022
1e78a4b
developing
lintangsutawika May 28, 2022
9070929
developing
lintangsutawika May 28, 2022
56c69de
developing
lintangsutawika May 28, 2022
d1ca914
developing
lintangsutawika May 28, 2022
13af623
developing
lintangsutawika May 28, 2022
dbc555e
developing
lintangsutawika May 28, 2022
12b209d
developing
lintangsutawika May 28, 2022
698eff0
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
dae3cc6
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
5c109c3
add new script
May 29, 2022
2fc9995
add new script
May 29, 2022
ee7af99
add new script
May 29, 2022
b6701a8
changed settings
May 30, 2022
2283e58
changed settings
May 30, 2022
9d00a49
tidy up
May 31, 2022
0298fde
changed tokenizer and position embedding
May 31, 2022
bde07f0
modifying mlm to reflect original implementation
Jun 2, 2022
4c0ca2e
minor fix
Jun 2, 2022
0c05596
minor fix
Jun 2, 2022
30f6924
minor fix
Jun 2, 2022
84408ef
minor fix
Jun 2, 2022
ad964c5
minor fix
Jun 2, 2022
45899e9
minor fix
Jun 2, 2022
0b94597
minor fix
Jun 2, 2022
2b54cc1
minor fix
Jun 2, 2022
ec61627
minor fix
Jun 2, 2022
4448d1d
minor fix
Jun 2, 2022
ecd148c
minor fix
Jun 2, 2022
a99f30f
minor fix
Jun 2, 2022
62d3e3e
minor fix
Jun 2, 2022
a160853
minor fix
Jun 2, 2022
fe205f7
minor fix
Jun 2, 2022
d39bdaf
minor fix
Jun 2, 2022
2530d3e
minor fix
Jun 2, 2022
5e93c47
minor fix
Jun 2, 2022
ad86799
minor fix
Jun 2, 2022
82c8d93
minor fix
Jun 2, 2022
ebf3561
minor fix
Jun 2, 2022
811f975
minor fix
Jun 2, 2022
de7dfc8
minor fix
Jun 2, 2022
be2af77
minor fix
Jun 2, 2022
5e7e18f
minor fix
Jun 2, 2022
24d4f25
minor fix
Jun 2, 2022
5926be1
minor fix
Jun 2, 2022
0f18174
minor fix
Jun 2, 2022
58ce714
minor fix
Jun 2, 2022
05470d7
set correct seq len
Jun 2, 2022
51a23f2
refined sampling method
Jun 8, 2022
43cb2f0
refined sampling method
Jun 8, 2022
901defc
refined sampling method
Jun 8, 2022
3130d7d
refined sampling method
Jun 8, 2022
18eb53d
refined sampling method
Jun 8, 2022
652c545
refined sampling method
Jun 8, 2022
5a49db8
first commit, adding non causal mlm dataset
Jun 8, 2022
81b918c
fixed mlm dataset
Jun 8, 2022
95afc4f
fixed mlm dataset
Jun 8, 2022
c4514d8
fixed mlm dataset
Jun 8, 2022
5cca5af
fixed mlm dataset
Jun 8, 2022
ae95878
fixed mlm dataset
Jun 8, 2022
a03e59f
minor changes
Jun 14, 2022
fa1e072
removed mlm related scripts
Jun 22, 2022
e3ce0a7
removed any scripts not related to dataset, revert arguments
Jun 22, 2022
87e4055
added sampler and test
Jun 23, 2022
0ae7661
added testing data
Jun 23, 2022
71fb5ae
adapted test loader
Jun 23, 2022
be0cea2
Update megatron/data/non_causal_mtf_dataset.py
Jun 24, 2022
9daa376
removed unused files
Jun 24, 2022
6b9e81a
changed with impossible token
Jun 24, 2022
7feec27
enable loading multiple indexed_dataset for each field
Jun 24, 2022
f84f293
minor fix
Jun 24, 2022
2778d8d
data_prefix is set as dict
Jun 24, 2022
61ac4b9
removed sample_idx lines
Jun 24, 2022
62e3fb1
change line from sample_idx to doc_idx
Jun 24, 2022
cb79f09
replace shuffling _build_index_mappings with random.sample of the doc…
Jun 25, 2022
e9cf22a
minor changes
Jun 25, 2022
acd87cd
Cleanup artefacts
Muennighoff Jun 27, 2022
019ed7c
Add packed preprocessing
Muennighoff Jun 28, 2022
7619f7a
Use seq_length arg
Muennighoff Jun 28, 2022
219209a
Add sources & docstrings
Muennighoff Jun 28, 2022
67424d6
added training process for t0
Jun 29, 2022
a7c424e
Update pretrain_t0.py
Jun 29, 2022
51d6c40
Remove a bunch of code that's not needed
thomasw21 Jun 29, 2022
b4e374c
WIP
thomasw21 Jun 30, 2022
0d2fdfd
Cleanup
thomasw21 Jun 30, 2022
126fa34
Add back all configs
thomasw21 Jun 30, 2022
83d2405
Woops
thomasw21 Jun 30, 2022
c93ed5c
Fix tests
thomasw21 Jun 30, 2022
528f5d3
Rename testing files
thomasw21 Jun 30, 2022
8bed302
Do in-place operations
thomasw21 Jun 30, 2022
bd2fede
Do in-place operations
thomasw21 Jun 30, 2022
8593e42
Woops
thomasw21 Jun 30, 2022
a1eb558
Fix typo
thomasw21 Jun 30, 2022
3bddafa
Add test that packing is done optimally via greedy algorithm
thomasw21 Jun 30, 2022
45c9444
Woops
thomasw21 Jun 30, 2022
c74dbb7
Update tests/test_dataloaders.py
thomasw21 Jun 30, 2022
a337563
Last segment is either pad or another document that fit perfectly
thomasw21 Jun 30, 2022
6227232
Resolve conflict
thomasw21 Jul 2, 2022
megatron/arguments.py (2 changes: 1 addition & 1 deletion)

@@ -553,7 +553,7 @@ def _add_training_args(parser):
                        'please refer https://github.com/facebookresearch/bitsandbytes.',
                        dest='use_bnb_optimizer')
     group.add_argument('--dataloader-type', type=str, default=None,
-                       choices=['single', 'cyclic'],
+                       choices=['single', 'cyclic', 'decoder_packed'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--cpu-optimizer', action='store_true',
                        help='Run optimizer on CPU')
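With this change in place, a fine-tuning run can opt into the packed loader by passing `--dataloader-type decoder_packed` alongside the usual training arguments; the existing `single` and `cyclic` choices keep their previous behavior.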
megatron/data/data_samplers.py (176 changes: 168 additions & 8 deletions)

@@ -15,11 +15,77 @@

"""Dataloaders."""

from functools import partial

import numpy as np
import torch
import random

-from megatron import get_args
+from megatron import get_args, get_tokenizer
from megatron import mpu
from megatron.data.mtf_dataset import MTFDataset


def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
    """
    Greedily packs samples.

    Items:
        [
            {
                'input_tokens': array([6, 7]),
                'target_tokens': array([8])
            },
            {
                'input_tokens': array([3, 4]),
                'target_tokens': array([5])
            }
        ]

    Output:
        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed by padding tokens.
        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids identify the original documents.
        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` marks input positions, `0` marks target and padding positions.
    """
    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))

    batch_num = 0
    # `0` is reserved for padding
    item_num = 1
    cur_len = 0
    for token_dict in items:
        input_token_len = len(token_dict["input_tokens"])
        target_token_len = len(token_dict["target_tokens"])
        total_len = input_token_len + target_token_len
        # If the sample does not fit in the current row, pad the rest of the row
        # and move on to the next one.
        if cur_len + total_len > max_seq_len:
            len_diff = max_seq_len - cur_len
            # Padding
            if len_diff > 0:
                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
            batch_num += 1
            assert batch_num < micro_batch_size
            item_num = 1
            cur_len = 0

        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target

        item_num += 1
        cur_len += total_len
        assert cur_len < max_seq_len

    return {
        "decoder_target_tokens": decoder_target_tokens,
        "decoder_segment_ids": decoder_segment_ids,
        "decoder_causal_attention": decoder_causal_attention,
    }
Review comment on lines +84 to +88:

Collaborator: Still need to cast them to torch tensors / directly use torch tensors instead of numpy.

Member Author: Doing it in the next one.
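For reference, a minimal sketch of calling `pack_samples` on the docstring's example; it assumes the function above is in scope, and the pad token `0` and micro batch size of 1 are arbitrary illustrative choices:

```python
import numpy as np

# Two toy samples, matching the docstring above.
items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]

packed = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)

# Expected output, as in the docstring:
#   packed["decoder_target_tokens"]    -> [[6, 7, 8, 3, 4, 5, 0]]
#   packed["decoder_segment_ids"]      -> [[1, 1, 1, 2, 2, 2, 0]]
#   packed["decoder_causal_attention"] -> [[1, 1, 0, 1, 1, 0, 0]]
```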



def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):

@@ -44,18 +110,39 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'decoder_packed':
        assert isinstance(dataset, MTFDataset)
        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
            sequence_length=args.seq_length + 1,
            dataset=dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    if num_workers is None:
        num_workers = args.num_workers

    collate_fn = None
    if args.dataloader_type == 'decoder_packed':
        assert isinstance(dataset, MTFDataset)
        pad_token = get_tokenizer().pad
        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
                             pad_token=pad_token)

    # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        pin_memory=True
+    )
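To summarize how the pieces fit together, a hedged sketch of the decoder_packed path at call time; it assumes Megatron's global args and tokenizer are already initialized and that `mtf_dataset` is an `MTFDataset` instance:

```python
# Illustrative only: the sampler yields lists of dataset indices, the DataLoader
# fetches the corresponding {input_tokens, target_tokens} dicts, and pack_samples
# (the collate_fn) packs them into fixed-size arrays.
loader = build_pretraining_data_loader(mtf_dataset, consumed_samples=0)
batch = next(iter(loader))

# `batch` is the dict produced by pack_samples, each entry shaped
# (micro_batch_size, seq_length + 1):
#   batch["decoder_target_tokens"], batch["decoder_segment_ids"],
#   batch["decoder_causal_attention"]
```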


class MegatronPretrainingSampler:

@@ -141,7 +228,7 @@ def __iter__(self):

        # data sharding and random sampling
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                      * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

@@ -158,3 +245,76 @@ def __iter__(self):
            self.consumed_samples += self.micro_batch_times_data_parallel_size
            yield batch
            batch = []


class MegatronDecoderPackedText2TextRandomSampler(object):
    """
    Samples a two-stream dataset with `input_tokens` and `target_tokens`, yielding batches of indices
    that can be greedily packed and passed to the decoder model.

    To be used with `pack_samples` as collate_fn.
    """

    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.sequence_length = sequence_length
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                      * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)

        random_idx = torch.randperm(bucket_size, generator=g).tolist()
Review comment on this line:

Member Author: I haven't thought about it very carefully, but essentially this means the entire dataset can be accessed randomly, which isn't possible if you pack things greedily (the problem: given an index, can you know which original samples you need to add to the batches?). So I'm refactoring this to put it in another dataset... Sorry about this @Muennighoff, you were right.
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        batch_count = 0
        token_lens = 0
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
            if token_lens + tok_len > self.sequence_length:
Review comment on this line:

Member Author (@thomasw21, Jul 2, 2022): I think a good opportunity to add the EOS token that we didn't have in pretraining is adding one here:

    {INPUT} {TARGET} <EOS>

Not sure whether we should also have one between input and target. WDYT @TevenLeScao @Muennighoff @lintangsutawika?

Collaborator: Good idea! I think we should add an EOS token only after each target.

Member Author: Yeah, agree on only after target.

Collaborator: Agreed.
                batch_count += 1
                token_lens = 0

            if batch_count == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch_count = 0
                batch = []
            else:
                token_lens += tok_len
                batch.append(idx)
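Following the EOS discussion above, a hedged sketch of what appending an EOS only after each target could look like before packing; `eos_token` is a hypothetical parameter that would have to be threaded through like `pad_token`, and this is not part of the current diff:

```python
import numpy as np

def append_eos(token_dict, eos_token):
    # Sketch only, reflecting the review discussion: append <EOS> after the
    # target tokens of one sample; the input tokens are left untouched.
    return {
        "input_tokens": token_dict["input_tokens"],
        "target_tokens": np.concatenate([token_dict["target_tokens"], [eos_token]]),
    }

# e.g. items = [append_eos(d, eos_token) for d in items] before calling pack_samples.
```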