MTF dataset and packing (#293)
Co-authored-by: Lintang Sutawika <lintang@datasaur.ai>
Co-authored-by: lintangsutawika <lintang@sutawika.com>
Co-authored-by: Muennighoff <n.muennighoff@gmail.com>
4 people authored Jul 2, 2022
1 parent 131bd43 commit c5b88fb
Showing 10 changed files with 693 additions and 15 deletions.
2 changes: 1 addition & 1 deletion megatron/arguments.py
@@ -553,7 +553,7 @@ def _add_training_args(parser):
'please refer https://github.com/facebookresearch/bitsandbytes.',
dest='use_bnb_optimizer')
group.add_argument('--dataloader-type', type=str, default=None,
-                       choices=['single', 'cyclic'],
+                       choices=['single', 'cyclic', 'decoder_packed'],
help='Single pass vs multiple pass data loader')
group.add_argument('--cpu-optimizer', action='store_true',
help='Run optimizer on CPU')
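After this change, the packed path is selected at launch with `--dataloader-type decoder_packed`; the rest of the commit wires that choice up in `megatron/data/data_samplers.py` below.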
176 changes: 168 additions & 8 deletions megatron/data/data_samplers.py
@@ -15,11 +15,77 @@

"""Dataloaders."""

from functools import partial

import numpy as np
import torch
import random
-from megatron import get_args

+from megatron import get_args, get_tokenizer
from megatron import mpu
from megatron.data.mtf_dataset import MTFDataset


def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
"""
    Greedily packs samples into rows of at most `max_seq_len` tokens.
Items:
[
{
'input_tokens': array([6, 7]),
'target_tokens': array([8])
},
{
'input_tokens': array([3, 4]),
'target_tokens': array([5])
}
]
Output:
    decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of the documents' tokens, followed by padding tokens.
    decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids mark which original document each position belongs to (`0` is padding).
    decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts targets.
"""

decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))

batch_num = 0
# `0` is reserved for padding
item_num = 1
cur_len = 0
for token_dict in items:
input_token_len = len(token_dict["input_tokens"])
target_token_len = len(token_dict["target_tokens"])
total_len = input_token_len + target_token_len
if cur_len + total_len > max_seq_len:
len_diff = max_seq_len - cur_len
# Padding
if len_diff > 0:
decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
batch_num += 1
assert batch_num < micro_batch_size
item_num = 1
cur_len = 0

decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1 # input
decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0 # target

item_num += 1
cur_len += total_len
assert cur_len < max_seq_len

return {
"decoder_target_tokens": decoder_target_tokens,
"decoder_segment_ids": decoder_segment_ids,
"decoder_causal_attention": decoder_causal_attention,
}
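For illustration, a minimal standalone sketch of what `pack_samples` returns for the two items from the docstring above, assuming a pad token id of 0 (training takes the pad id from `get_tokenizer().pad`):

import numpy as np
from megatron.data.data_samplers import pack_samples

items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
packed = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)
# packed["decoder_target_tokens"]     -> [[6 7 8 3 4 5 0]]
# packed["decoder_segment_ids"]       -> [[1. 1. 1. 2. 2. 2. 0.]]
# packed["decoder_causal_attention"]  -> [[1. 1. 0. 1. 1. 0. 0.]]

The three arrays line up position by position: the packed token values, the id of the document each position belongs to, and whether the position is part of the input segment (1) or the target segment (0).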


def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
@@ -44,18 +110,39 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'decoder_packed':
assert isinstance(dataset, MTFDataset)
batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
sequence_length=args.seq_length + 1,
dataset=dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
raise Exception('{} dataloader type is not supported.'.format(
-                args.dataloader_type))
+            args.dataloader_type))

if num_workers is None:
num_workers = args.num_workers

collate_fn = None
if args.dataloader_type == 'decoder_packed':
assert isinstance(dataset, MTFDataset)
pad_token = get_tokenizer().pad
collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
pad_token=pad_token)

# Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        pin_memory=True
+    )


class MegatronPretrainingSampler:

@@ -141,7 +228,7 @@ def __iter__(self):

# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
+                       * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

@@ -158,3 +245,76 @@ def __iter__(self):
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []


class MegatronDecoderPackedText2TextRandomSampler(object):
"""
    Samples from a two-stream dataset with `input_tokens` and `target_tokens`, yielding index batches that can be
    greedily packed and passed to the decoder model.
    To be used with `pack_samples` as the collate_fn.
"""

def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.dataset = dataset
self.sequence_length = sequence_length
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)

def __len__(self):
return self.total_samples

def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)

random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

batch = []
batch_count = 0
token_lens = 0
        # The last batch is dropped if it is incomplete.
for idx in idx_range:
tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
if token_lens + tok_len > self.sequence_length:
batch_count += 1
token_lens = 0

if batch_count == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch_count = 0
batch = []
else:
token_lens += tok_len
batch.append(idx)
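
For a rough end-to-end sketch of how the sampler and `pack_samples` fit together, here a plain list of token dicts stands in for an `MTFDataset`, data parallelism is disabled, and the pad id of 0 and all sizes are illustrative assumptions rather than real training settings:

from functools import partial

import numpy as np
import torch

from megatron.data.data_samplers import (
    MegatronDecoderPackedText2TextRandomSampler,
    pack_samples,
)

# Toy stand-in for an MTFDataset: 6 documents, each with a 1-token input and a 1-token target.
dataset = [
    {"input_tokens": np.array([10 + i]), "target_tokens": np.array([20 + i])}
    for i in range(6)
]

packed_len = 9        # stands in for args.seq_length + 1
micro_batch_size = 1  # illustrative

batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=packed_len,
    dataset=dataset,
    total_samples=len(dataset),
    consumed_samples=0,
    micro_batch_size=micro_batch_size,
    data_parallel_rank=0,
    data_parallel_size=1,
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=partial(pack_samples, max_seq_len=packed_len,
                       micro_batch_size=micro_batch_size, pad_token=0),  # pad id 0 assumed
)

for batch in loader:
    # Every value is a (micro_batch_size, packed_len) array holding several packed documents.
    print(batch["decoder_target_tokens"].shape, batch["decoder_segment_ids"].max())

Each batch coming out of the loader is already collated by `pack_samples`, so downstream code only sees fixed-shape `decoder_*` arrays of shape `(micro_batch_size, max_seq_len)` per micro batch.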