Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLM] Train script for non causal decoder #300

Draft
wants to merge 298 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
298 commits
Select commit Hold shift + click to select a range
6ad61b6
made into input and output tokens
May 23, 2022
9131fdd
added eos
May 23, 2022
cb76cd3
added eos
May 23, 2022
531ee68
test text_token
May 24, 2022
a7d1158
test text_token
May 24, 2022
0008cfb
test text_token
May 24, 2022
f1461a8
test text_token
May 24, 2022
ada0f10
test text_token
May 24, 2022
298c9b7
assigned array
May 24, 2022
d2bdff6
assigned array
May 24, 2022
4ec8db3
assigned array
May 24, 2022
10a2b6d
hardcoded sequence length
May 24, 2022
a373a70
check again
May 28, 2022
bdef71b
show sentinal tokens
lintangsutawika May 28, 2022
262fd6c
show sentinal tokens
lintangsutawika May 28, 2022
68a6a93
show sentinal tokens
lintangsutawika May 28, 2022
1c00d4b
show sentinal tokens
lintangsutawika May 28, 2022
8b85f11
add more special tokens
lintangsutawika May 28, 2022
85d204a
changed how mlm data is loaded
lintangsutawika May 28, 2022
4c84274
changed how mlm data is loaded
lintangsutawika May 28, 2022
084245e
changed how mlm data is loaded
lintangsutawika May 28, 2022
32af10e
changed how mlm data is loaded
lintangsutawika May 28, 2022
b6e0e63
changed how mlm data is loaded
lintangsutawika May 28, 2022
2af2e4b
added new script
lintangsutawika May 28, 2022
cc5968e
added new script
lintangsutawika May 28, 2022
cf0b2a0
added new script
lintangsutawika May 28, 2022
fc150a0
try t5 dataset
lintangsutawika May 28, 2022
039f90f
try t5 dataset
lintangsutawika May 28, 2022
7364781
try t5 dataset
lintangsutawika May 28, 2022
5b1100a
try t5 dataset
lintangsutawika May 28, 2022
45102a9
try t5 dataset
lintangsutawika May 28, 2022
7b2ebbf
try t5 dataset
lintangsutawika May 28, 2022
fe8b3dc
try t5 dataset
lintangsutawika May 28, 2022
f456725
try t5 dataset
lintangsutawika May 28, 2022
ae73d8c
try t5 dataset
lintangsutawika May 28, 2022
fae6a0b
try t5 dataset
lintangsutawika May 28, 2022
8185842
try t5 dataset
lintangsutawika May 28, 2022
9deef49
try t5 dataset
lintangsutawika May 28, 2022
1e78a4b
developing
lintangsutawika May 28, 2022
9070929
developing
lintangsutawika May 28, 2022
56c69de
developing
lintangsutawika May 28, 2022
d1ca914
developing
lintangsutawika May 28, 2022
13af623
developing
lintangsutawika May 28, 2022
dbc555e
developing
lintangsutawika May 28, 2022
12b209d
developing
lintangsutawika May 28, 2022
698eff0
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
dae3cc6
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
5c109c3
add new script
May 29, 2022
2fc9995
add new script
May 29, 2022
ee7af99
add new script
May 29, 2022
b6701a8
changed settings
May 30, 2022
2283e58
changed settings
May 30, 2022
9d00a49
tidy up
May 31, 2022
0298fde
changed tokenizer and position embedding
May 31, 2022
bde07f0
modifying mlm to reflect original implementation
Jun 2, 2022
4c0ca2e
minor fix
Jun 2, 2022
0c05596
minor fix
Jun 2, 2022
30f6924
minor fix
Jun 2, 2022
84408ef
minor fix
Jun 2, 2022
ad964c5
minor fix
Jun 2, 2022
45899e9
minor fix
Jun 2, 2022
0b94597
minor fix
Jun 2, 2022
2b54cc1
minor fix
Jun 2, 2022
ec61627
minor fix
Jun 2, 2022
4448d1d
minor fix
Jun 2, 2022
ecd148c
minor fix
Jun 2, 2022
a99f30f
minor fix
Jun 2, 2022
62d3e3e
minor fix
Jun 2, 2022
a160853
minor fix
Jun 2, 2022
fe205f7
minor fix
Jun 2, 2022
d39bdaf
minor fix
Jun 2, 2022
2530d3e
minor fix
Jun 2, 2022
5e93c47
minor fix
Jun 2, 2022
ad86799
minor fix
Jun 2, 2022
82c8d93
minor fix
Jun 2, 2022
ebf3561
minor fix
Jun 2, 2022
811f975
minor fix
Jun 2, 2022
de7dfc8
minor fix
Jun 2, 2022
be2af77
minor fix
Jun 2, 2022
5e7e18f
minor fix
Jun 2, 2022
24d4f25
minor fix
Jun 2, 2022
5926be1
minor fix
Jun 2, 2022
0f18174
minor fix
Jun 2, 2022
58ce714
minor fix
Jun 2, 2022
05470d7
set correct seq len
Jun 2, 2022
51a23f2
refined sampling method
Jun 8, 2022
43cb2f0
refined sampling method
Jun 8, 2022
901defc
refined sampling method
Jun 8, 2022
3130d7d
refined sampling method
Jun 8, 2022
18eb53d
refined sampling method
Jun 8, 2022
652c545
refined sampling method
Jun 8, 2022
5a49db8
first commit, adding non causal mlm dataset
Jun 8, 2022
81b918c
fixed mlm dataset
Jun 8, 2022
95afc4f
fixed mlm dataset
Jun 8, 2022
c4514d8
fixed mlm dataset
Jun 8, 2022
5cca5af
fixed mlm dataset
Jun 8, 2022
ae95878
fixed mlm dataset
Jun 8, 2022
a03e59f
minor changes
Jun 14, 2022
fa1e072
removed mlm related scripts
Jun 22, 2022
e3ce0a7
removed any scipts not related to dataset, revert arguments
Jun 22, 2022
87e4055
added sampler and test
Jun 23, 2022
0ae7661
added testing data
Jun 23, 2022
71fb5ae
adapted test loader
Jun 23, 2022
be0cea2
Update megatron/data/non_causal_mtf_dataset.py
Jun 24, 2022
9daa376
removed unused files
Jun 24, 2022
6b9e81a
changed with impossible token
Jun 24, 2022
7feec27
enable loading multiple indexed_dataset for each field
Jun 24, 2022
f84f293
minor fix
Jun 24, 2022
2778d8d
data_prefix is set as dict
Jun 24, 2022
61ac4b9
removed sample_idx lines
Jun 24, 2022
62e3fb1
change line from sample_idx to doc_idx
Jun 24, 2022
cb79f09
replace shuffling _build_index_mappings with random.sample of the doc…
Jun 25, 2022
e9cf22a
minor changes
Jun 25, 2022
acd87cd
Cleanup artefacts
Muennighoff Jun 27, 2022
019ed7c
Add packed preprocessing
Muennighoff Jun 28, 2022
7619f7a
Use seq_length arg
Muennighoff Jun 28, 2022
219209a
Add sources & docstrings
Muennighoff Jun 28, 2022
67424d6
added training process for t0
Jun 29, 2022
a7c424e
Update pretrain_t0.py
Jun 29, 2022
51d6c40
Remove a bunch of code that's not needed
thomasw21 Jun 29, 2022
b4e374c
WIP
thomasw21 Jun 30, 2022
0d2fdfd
Cleanup
thomasw21 Jun 30, 2022
126fa34
Add back all configs
thomasw21 Jun 30, 2022
83d2405
Woops
thomasw21 Jun 30, 2022
c93ed5c
Fix tests
thomasw21 Jun 30, 2022
528f5d3
Rename testing files
thomasw21 Jun 30, 2022
8bed302
Do in-place operations
thomasw21 Jun 30, 2022
bd2fede
Do in-place operations
thomasw21 Jun 30, 2022
8593e42
Woops
thomasw21 Jun 30, 2022
a1eb558
Fix typo
thomasw21 Jun 30, 2022
3bddafa
Add test that packing is done optimially via greedy algorithm
thomasw21 Jun 30, 2022
45c9444
Woops
thomasw21 Jun 30, 2022
6f28ae4
added capabilities for padding and prefix lm index
lintangsutawika May 9, 2022
8a4d99b
added adjustments and new dataset
May 9, 2022
ea445b1
added sentinal tokens
May 21, 2022
4070859
made into input and output tokens
May 23, 2022
85e84ec
modifying mlm to reflect original implementation
Jun 2, 2022
3922293
minor fix
Jun 2, 2022
ee6438f
added sampler and test
Jun 23, 2022
a869adf
Enable training
Muennighoff Jun 29, 2022
5ae15ef
Add T0 training test
Muennighoff Jun 30, 2022
efa55ea
Remove artefacts
Muennighoff Jun 30, 2022
f45266d
Remove artefacts
Muennighoff Jun 30, 2022
8029564
WIP
thomasw21 Jun 30, 2022
4faa743
WIP
thomasw21 Jul 1, 2022
3a6d73d
WIP
thomasw21 Jul 1, 2022
ea86bc8
WIP
thomasw21 Jul 1, 2022
638fc56
WIP
thomasw21 Jul 1, 2022
66d2afe
move to cpu for comparison
thomasw21 Jul 1, 2022
3794b86
Use torch_assert_equal
thomasw21 Jul 1, 2022
346b08f
WIP
thomasw21 Jul 1, 2022
4203f6c
Take in account pad + fix inverse
thomasw21 Jul 1, 2022
bcba2b7
Tensor and int can't be compared vi torch_assert_equal
thomasw21 Jul 1, 2022
57156e1
Woops
thomasw21 Jul 1, 2022
45d9218
Test
thomasw21 Jul 1, 2022
959fc71
Woops
thomasw21 Jul 1, 2022
27197fc
Remove unecessary unsqueeze
thomasw21 Jul 1, 2022
b7374e1
Add necessary unsqueeze
thomasw21 Jul 1, 2022
4f6b7d3
I'm stupid
thomasw21 Jul 1, 2022
960b17c
I'm stupid
thomasw21 Jul 1, 2022
2b522d1
Tokenizers returns None when trying to access a non existing value
thomasw21 Jul 1, 2022
a8fcd38
Force gpt2 to have a pad token
thomasw21 Jul 1, 2022
7181de4
Add a test that the packed_masking works in the modeling side
thomasw21 Jul 1, 2022
172306b
Import error
thomasw21 Jul 1, 2022
a4854bd
Tokenizer requires to have pad token
thomasw21 Jul 1, 2022
06c29a9
Turns out that test_model.py did not use deepspeed version of models
thomasw21 Jul 1, 2022
aba48b3
Use train_batch instead
thomasw21 Jul 1, 2022
a9d423a
Make it work via DS
thomasw21 Jul 1, 2022
6a95e25
Make it work via DS
thomasw21 Jul 1, 2022
d6e435b
Make it work via DS
thomasw21 Jul 1, 2022
ca8c04a
Make it work via DS
thomasw21 Jul 1, 2022
f3231db
Make it work via DS
thomasw21 Jul 1, 2022
987e6b4
Make it work via DS
thomasw21 Jul 1, 2022
0b27fb6
Make it work via DS
thomasw21 Jul 1, 2022
1ba5d4a
Woops
thomasw21 Jul 1, 2022
cbab16c
Make it work via DS
thomasw21 Jul 1, 2022
4defbb2
Make it work via DS
thomasw21 Jul 1, 2022
412939c
Make it work via DS
thomasw21 Jul 1, 2022
17a6cc0
Maybe
thomasw21 Jul 1, 2022
cb90679
Make it work via DS
thomasw21 Jul 1, 2022
bd4a3f0
Woops
thomasw21 Jul 1, 2022
6604035
Try having very strict mask
thomasw21 Jul 1, 2022
d98e39a
Try updating the kernel
thomasw21 Jul 1, 2022
8495083
Try updating the kernel
thomasw21 Jul 1, 2022
ef5d4d4
Try updating the kernel
thomasw21 Jul 1, 2022
69912b3
Try updating the kernel
thomasw21 Jul 1, 2022
866fc56
Try updating the kernel
thomasw21 Jul 1, 2022
8e9701b
Try updating the kernel
thomasw21 Jul 1, 2022
15d95fa
Inverse causal masking
thomasw21 Jul 1, 2022
fe4f806
Check that the padding are ignored
thomasw21 Jul 1, 2022
cc2aff5
Fix test
thomasw21 Jul 1, 2022
93cde87
Probably should be in this order:
thomasw21 Jul 1, 2022
f6d717b
Revert "Probably should be in this order:"
thomasw21 Jul 1, 2022
910f93b
Add a test checking that ScaledMaskedSoftmax custom kernel does what …
thomasw21 Jul 1, 2022
75f99ef
Head specific mask is not implemented
thomasw21 Jul 1, 2022
c34f107
Test something out
thomasw21 Jul 2, 2022
ed6131a
Test something out
thomasw21 Jul 2, 2022
3a846a0
Test something out
thomasw21 Jul 2, 2022
5746641
Test something out
thomasw21 Jul 2, 2022
292620c
Test something out
thomasw21 Jul 2, 2022
0e1ef5d
Test something out
thomasw21 Jul 2, 2022
964a275
Test something out
thomasw21 Jul 2, 2022
8b31e9c
Test something out
thomasw21 Jul 2, 2022
723a5b3
Test something out
thomasw21 Jul 2, 2022
65b4ea2
Test something out
thomasw21 Jul 2, 2022
7eaced4
Maybe nothing is wrong
thomasw21 Jul 2, 2022
da9f316
Woops
thomasw21 Jul 2, 2022
8b67bd9
Use bloom instead
thomasw21 Jul 2, 2022
84007bc
Make MTF dataloader an infinite dataloader
thomasw21 Jul 2, 2022
273d420
Work into moving packing logic into a dataset
thomasw21 Jul 2, 2022
688d06e
Woops
thomasw21 Jul 2, 2022
ddc6a61
Woops
thomasw21 Jul 2, 2022
0e34e8d
Woops
thomasw21 Jul 2, 2022
014b8b8
Woops
thomasw21 Jul 2, 2022
c53622a
Woops
thomasw21 Jul 2, 2022
ea221a8
Woops
thomasw21 Jul 2, 2022
3274986
Woops
thomasw21 Jul 2, 2022
9a5bf96
Woops
thomasw21 Jul 2, 2022
d160589
Woops
thomasw21 Jul 2, 2022
c3ab5b9
Woops
thomasw21 Jul 2, 2022
f541076
Woops
thomasw21 Jul 2, 2022
20be5b9
Requires to remember how may epochs
thomasw21 Jul 2, 2022
d9719b6
Find a way to reset states everytime
thomasw21 Jul 2, 2022
4e0c4ca
Find a way to reset states everytime
thomasw21 Jul 2, 2022
48a55b9
Find a way to reset states everytime
thomasw21 Jul 2, 2022
2e469e5
Find a way to reset states everytime
thomasw21 Jul 2, 2022
74e03ec
Find a way to reset states everytime
thomasw21 Jul 2, 2022
f4a4733
Fix bugs
thomasw21 Jul 2, 2022
e1a3767
Cleanup
thomasw21 Jul 2, 2022
efeb55a
Merge remote-tracking branch 'official_repo/main' into thomas/mtf_tra…
thomasw21 Jul 2, 2022
de88ab6
Woops
thomasw21 Jul 2, 2022
d7a6388
Woops
thomasw21 Jul 2, 2022
1c2284f
Woops
thomasw21 Jul 2, 2022
b759a92
Woops
thomasw21 Jul 2, 2022
ef20e57
Woops
thomasw21 Jul 2, 2022
5816adf
Silently skip samples that are too long
thomasw21 Jul 2, 2022
37ad57e
Build the index from scratch everytime
thomasw21 Jul 2, 2022
1572ddc
Prevent empty dataset
thomasw21 Jul 2, 2022
bebb481
Change the condition for empty slice
thomasw21 Jul 2, 2022
5c80699
PR reviews
thomasw21 Jul 3, 2022
985cd02
Revert back changes linked to shutil.copytree
thomasw21 Jul 3, 2022
41e931a
Get test working
thomasw21 Jul 3, 2022
b321a34
Woops
thomasw21 Jul 3, 2022
0450bad
Woops
thomasw21 Jul 3, 2022
de4934f
Fix empty samples
thomasw21 Jul 3, 2022
e3e21f5
Cuda kernel is not strictly equivalent
thomasw21 Jul 3, 2022
16c556c
Update tests/test_model.py
thomasw21 Jul 4, 2022
f2df771
MTF optimize dataloading (#298)
thomasw21 Jul 4, 2022
a45c9cd
Get pretrain on non causal mlm script
thomasw21 Jul 4, 2022
606fdeb
Test
thomasw21 Jul 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions finetune_t0_non_causal_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Multitask Finetuning T0"""

from multiprocessing.sharedctypes import Value
import torch

from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_attention_masks_and_position_ids, get_packed_attention_mask

import deepspeed
from deepspeed.runtime.utils import see_memory_usage

try:
from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
# noop
def record(fn):
return fn

def model_provider(pre_process=True, post_process=True):
"""Build the model."""

print_rank_0("building GPT model ...")
see_memory_usage(f"Before Building Model", force=True)

args = get_args()

with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
model = GPTModelPipe(
num_tokentypes=0,
parallel_output=True,
attn_mask_type=AttnMaskType.custom
)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
else:
raise NotImplementedError("DeepSpeed is required for T0")

see_memory_usage(f"After Building Model", force=True)
return model

def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion

data:
decoder_tokens = [[6, 7, 8, 3, 4, 5, 0]]
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
"""
args = get_args()
tokenizer = get_tokenizer()

# Broadcast data.
data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64)
data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool)

# Unpack.
tokens_ = data_b["decoder_token_ids"].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

segment_ids = data_b["decoder_segment_ids"].long()[:, :-1]
decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1]

# Get the masks and position ids.
causal_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=False # This is done below
)
# Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
loss_on_non_pad_only = (tokens != tokenizer.pad)
loss_mask *= loss_on_targets_only * loss_on_non_pad_only

attention_mask = get_packed_attention_mask(
# Run non-causal decoder
is_causal=False,
causal_mask=~(causal_mask.bool()),
decoder_is_inputs=decoder_is_inputs.bool(),
segment_ids=segment_ids.long(),
)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

return (tokens, position_ids, attention_mask), (labels, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
train_ds, valid_ds, test_ds = None, None, None

tokenizer = get_tokenizer()

print_rank_0("> building train, validation, and test datasets for T0 ...")
# Option 1 of data loading using --data-path
if args.data_path:
# TODO: Not yet compatible with dataset weights (Will break at prefixes, weights = analyze_data_prefix(args.data_path))
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
seq_length=args.seq_length + 1,
pad_token=tokenizer.pad,
eos_token=tokenizer.eos,
train_valid_test_num_samples=train_val_test_num_samples,
seed=args.seed,
skip_warmup=(not args.mmap_warmup)
)
else:
raise NotImplementedError("No dataloading argument passed")

print_rank_0("> finished creating T0 datasets ...")
return train_ds, valid_ds, test_ds

@record
def main():
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step_func=None,
args_defaults={}
)

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def _add_training_args(parser):
'please refer https://github.com/facebookresearch/bitsandbytes.',
dest='use_bnb_optimizer')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic', 'decoder_packed'],
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--cpu-optimizer', action='store_true',
help='Run optimizer on CPU')
Expand Down
164 changes: 4 additions & 160 deletions megatron/data/data_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,77 +15,11 @@

"""Dataloaders."""

from functools import partial

import numpy as np
import torch

from megatron import get_args, get_tokenizer
from megatron import get_args
from megatron import mpu
from megatron.data.mtf_dataset import MTFDataset


def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
"""
Greedily packs samples.

Items:
[
{
'input_tokens': array([6, 7]),
'target_tokens': array([8])
},
{
'input_tokens': array([3, 4]),
'target_tokens': array([5])
}
]

Output:
decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `0` depicts inputs, `1` depicts target.
"""

decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))

batch_num = 0
# `0` is reserved for padding
item_num = 1
cur_len = 0
for token_dict in items:
input_token_len = len(token_dict["input_tokens"])
target_token_len = len(token_dict["target_tokens"])
total_len = input_token_len + target_token_len
if cur_len + total_len > max_seq_len:
len_diff = max_seq_len - cur_len
# Padding
if len_diff > 0:
decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
batch_num += 1
assert batch_num < micro_batch_size
item_num = 1
cur_len = 0

decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1 # input
decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0 # target

item_num += 1
cur_len += total_len
assert cur_len < max_seq_len

return {
"decoder_target_tokens": decoder_target_tokens,
"decoder_segment_ids": decoder_segment_ids,
"decoder_causal_attention": decoder_causal_attention,
}
from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset


def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
Expand All @@ -110,40 +44,22 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'decoder_packed':
assert isinstance(dataset, MTFDataset)
batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
sequence_length=args.seq_length + 1,
dataset=dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))

if num_workers is None:
num_workers = args.num_workers

collate_fn = None
if args.dataloader_type == 'decoder_packed':
assert isinstance(dataset, MTFDataset)
pad_token = get_tokenizer().pad
collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
pad_token=pad_token)

# Torch dataloader.
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
collate_fn=None,
pin_memory=True
)


class MegatronPretrainingSampler:

def __init__(self, total_samples, consumed_samples, micro_batch_size,
Expand Down Expand Up @@ -234,6 +150,7 @@ def __iter__(self):

g = torch.Generator()
g.manual_seed(self.epoch)

random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

Expand All @@ -245,76 +162,3 @@ def __iter__(self):
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []


class MegatronDecoderPackedText2TextRandomSampler(object):
"""
Converts a two stream dataset with `input_tokens` and `target_tokens` and creates a batch that should be greedily
packed to be passed onto the decoder model.

To be used with `pack_samples` as collate_fn
"""

def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.dataset = dataset
self.sequence_length = sequence_length
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)

def __len__(self):
return self.total_samples

def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)

random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

batch = []
batch_count = 0
token_lens = 0
# Last batch if not complete will be dropped.
for idx in idx_range:
tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
if token_lens + tok_len > self.sequence_length:
batch_count += 1
token_lens = 0

if batch_count == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch_count = 0
batch = []
else:
token_lens += tok_len
batch.append(idx)
Loading