[MLM] Train script for non causal decoder #300

Draft
wants to merge 298 commits into base: main
Changes from 1 commit
298 commits
6ad61b6
made into input and output tokens
May 23, 2022
9131fdd
added eos
May 23, 2022
cb76cd3
added eos
May 23, 2022
531ee68
test text_token
May 24, 2022
a7d1158
test text_token
May 24, 2022
0008cfb
test text_token
May 24, 2022
f1461a8
test text_token
May 24, 2022
ada0f10
test text_token
May 24, 2022
298c9b7
assigned array
May 24, 2022
d2bdff6
assigned array
May 24, 2022
4ec8db3
assigned array
May 24, 2022
10a2b6d
hardcoded sequence length
May 24, 2022
a373a70
check again
May 28, 2022
bdef71b
show sentinal tokens
lintangsutawika May 28, 2022
262fd6c
show sentinal tokens
lintangsutawika May 28, 2022
68a6a93
show sentinal tokens
lintangsutawika May 28, 2022
1c00d4b
show sentinal tokens
lintangsutawika May 28, 2022
8b85f11
add more special tokens
lintangsutawika May 28, 2022
85d204a
changed how mlm data is loaded
lintangsutawika May 28, 2022
4c84274
changed how mlm data is loaded
lintangsutawika May 28, 2022
084245e
changed how mlm data is loaded
lintangsutawika May 28, 2022
32af10e
changed how mlm data is loaded
lintangsutawika May 28, 2022
b6e0e63
changed how mlm data is loaded
lintangsutawika May 28, 2022
2af2e4b
added new script
lintangsutawika May 28, 2022
cc5968e
added new script
lintangsutawika May 28, 2022
cf0b2a0
added new script
lintangsutawika May 28, 2022
fc150a0
try t5 dataset
lintangsutawika May 28, 2022
039f90f
try t5 dataset
lintangsutawika May 28, 2022
7364781
try t5 dataset
lintangsutawika May 28, 2022
5b1100a
try t5 dataset
lintangsutawika May 28, 2022
45102a9
try t5 dataset
lintangsutawika May 28, 2022
7b2ebbf
try t5 dataset
lintangsutawika May 28, 2022
fe8b3dc
try t5 dataset
lintangsutawika May 28, 2022
f456725
try t5 dataset
lintangsutawika May 28, 2022
ae73d8c
try t5 dataset
lintangsutawika May 28, 2022
fae6a0b
try t5 dataset
lintangsutawika May 28, 2022
8185842
try t5 dataset
lintangsutawika May 28, 2022
9deef49
try t5 dataset
lintangsutawika May 28, 2022
1e78a4b
developing
lintangsutawika May 28, 2022
9070929
developing
lintangsutawika May 28, 2022
56c69de
developing
lintangsutawika May 28, 2022
d1ca914
developing
lintangsutawika May 28, 2022
13af623
developing
lintangsutawika May 28, 2022
dbc555e
developing
lintangsutawika May 28, 2022
12b209d
developing
lintangsutawika May 28, 2022
698eff0
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
dae3cc6
test to see output of get_ltor_masks_and_position_ids
lintangsutawika May 29, 2022
5c109c3
add new script
May 29, 2022
2fc9995
add new script
May 29, 2022
ee7af99
add new script
May 29, 2022
b6701a8
changed settings
May 30, 2022
2283e58
changed settings
May 30, 2022
9d00a49
tidy up
May 31, 2022
0298fde
changed tokenizer and position embedding
May 31, 2022
bde07f0
modifying mlm to reflect original implementation
Jun 2, 2022
4c0ca2e
minor fix
Jun 2, 2022
0c05596
minor fix
Jun 2, 2022
30f6924
minor fix
Jun 2, 2022
84408ef
minor fix
Jun 2, 2022
ad964c5
minor fix
Jun 2, 2022
45899e9
minor fix
Jun 2, 2022
0b94597
minor fix
Jun 2, 2022
2b54cc1
minor fix
Jun 2, 2022
ec61627
minor fix
Jun 2, 2022
4448d1d
minor fix
Jun 2, 2022
ecd148c
minor fix
Jun 2, 2022
a99f30f
minor fix
Jun 2, 2022
62d3e3e
minor fix
Jun 2, 2022
a160853
minor fix
Jun 2, 2022
fe205f7
minor fix
Jun 2, 2022
d39bdaf
minor fix
Jun 2, 2022
2530d3e
minor fix
Jun 2, 2022
5e93c47
minor fix
Jun 2, 2022
ad86799
minor fix
Jun 2, 2022
82c8d93
minor fix
Jun 2, 2022
ebf3561
minor fix
Jun 2, 2022
811f975
minor fix
Jun 2, 2022
de7dfc8
minor fix
Jun 2, 2022
be2af77
minor fix
Jun 2, 2022
5e7e18f
minor fix
Jun 2, 2022
24d4f25
minor fix
Jun 2, 2022
5926be1
minor fix
Jun 2, 2022
0f18174
minor fix
Jun 2, 2022
58ce714
minor fix
Jun 2, 2022
05470d7
set correct seq len
Jun 2, 2022
51a23f2
refined sampling method
Jun 8, 2022
43cb2f0
refined sampling method
Jun 8, 2022
901defc
refined sampling method
Jun 8, 2022
3130d7d
refined sampling method
Jun 8, 2022
18eb53d
refined sampling method
Jun 8, 2022
652c545
refined sampling method
Jun 8, 2022
5a49db8
first commit, adding non causal mlm dataset
Jun 8, 2022
81b918c
fixed mlm dataset
Jun 8, 2022
95afc4f
fixed mlm dataset
Jun 8, 2022
c4514d8
fixed mlm dataset
Jun 8, 2022
5cca5af
fixed mlm dataset
Jun 8, 2022
ae95878
fixed mlm dataset
Jun 8, 2022
a03e59f
minor changes
Jun 14, 2022
fa1e072
removed mlm related scripts
Jun 22, 2022
e3ce0a7
removed any scipts not related to dataset, revert arguments
Jun 22, 2022
87e4055
added sampler and test
Jun 23, 2022
0ae7661
added testing data
Jun 23, 2022
71fb5ae
adapted test loader
Jun 23, 2022
be0cea2
Update megatron/data/non_causal_mtf_dataset.py
Jun 24, 2022
9daa376
removed unused files
Jun 24, 2022
6b9e81a
changed with impossible token
Jun 24, 2022
7feec27
enable loading multiple indexed_dataset for each field
Jun 24, 2022
f84f293
minor fix
Jun 24, 2022
2778d8d
data_prefix is set as dict
Jun 24, 2022
61ac4b9
removed sample_idx lines
Jun 24, 2022
62e3fb1
change line from sample_idx to doc_idx
Jun 24, 2022
cb79f09
replace shuffling _build_index_mappings with random.sample of the doc…
Jun 25, 2022
e9cf22a
minor changes
Jun 25, 2022
acd87cd
Cleanup artefacts
Muennighoff Jun 27, 2022
019ed7c
Add packed preprocessing
Muennighoff Jun 28, 2022
7619f7a
Use seq_length arg
Muennighoff Jun 28, 2022
219209a
Add sources & docstrings
Muennighoff Jun 28, 2022
67424d6
added training process for t0
Jun 29, 2022
a7c424e
Update pretrain_t0.py
Jun 29, 2022
51d6c40
Remove a bunch of code that's not needed
thomasw21 Jun 29, 2022
b4e374c
WIP
thomasw21 Jun 30, 2022
0d2fdfd
Cleanup
thomasw21 Jun 30, 2022
126fa34
Add back all configs
thomasw21 Jun 30, 2022
83d2405
Woops
thomasw21 Jun 30, 2022
c93ed5c
Fix tests
thomasw21 Jun 30, 2022
528f5d3
Rename testing files
thomasw21 Jun 30, 2022
8bed302
Do in-place operations
thomasw21 Jun 30, 2022
bd2fede
Do in-place operations
thomasw21 Jun 30, 2022
8593e42
Woops
thomasw21 Jun 30, 2022
a1eb558
Fix typo
thomasw21 Jun 30, 2022
3bddafa
Add test that packing is done optimially via greedy algorithm
thomasw21 Jun 30, 2022
45c9444
Woops
thomasw21 Jun 30, 2022
6f28ae4
added capabilities for padding and prefix lm index
lintangsutawika May 9, 2022
8a4d99b
added adjustments and new dataset
May 9, 2022
ea445b1
added sentinal tokens
May 21, 2022
4070859
made into input and output tokens
May 23, 2022
85e84ec
modifying mlm to reflect original implementation
Jun 2, 2022
3922293
minor fix
Jun 2, 2022
ee6438f
added sampler and test
Jun 23, 2022
a869adf
Enable training
Muennighoff Jun 29, 2022
5ae15ef
Add T0 training test
Muennighoff Jun 30, 2022
efa55ea
Remove artefacts
Muennighoff Jun 30, 2022
f45266d
Remove artefacts
Muennighoff Jun 30, 2022
8029564
WIP
thomasw21 Jun 30, 2022
4faa743
WIP
thomasw21 Jul 1, 2022
3a6d73d
WIP
thomasw21 Jul 1, 2022
ea86bc8
WIP
thomasw21 Jul 1, 2022
638fc56
WIP
thomasw21 Jul 1, 2022
66d2afe
move to cpu for comparison
thomasw21 Jul 1, 2022
3794b86
Use torch_assert_equal
thomasw21 Jul 1, 2022
346b08f
WIP
thomasw21 Jul 1, 2022
4203f6c
Take in account pad + fix inverse
thomasw21 Jul 1, 2022
bcba2b7
Tensor and int can't be compared vi torch_assert_equal
thomasw21 Jul 1, 2022
57156e1
Woops
thomasw21 Jul 1, 2022
45d9218
Test
thomasw21 Jul 1, 2022
959fc71
Woops
thomasw21 Jul 1, 2022
27197fc
Remove unecessary unsqueeze
thomasw21 Jul 1, 2022
b7374e1
Add necessary unsqueeze
thomasw21 Jul 1, 2022
4f6b7d3
I'm stupid
thomasw21 Jul 1, 2022
960b17c
I'm stupid
thomasw21 Jul 1, 2022
2b522d1
Tokenizers returns None when trying to access a non existing value
thomasw21 Jul 1, 2022
a8fcd38
Force gpt2 to have a pad token
thomasw21 Jul 1, 2022
7181de4
Add a test that the packed_masking works in the modeling side
thomasw21 Jul 1, 2022
172306b
Import error
thomasw21 Jul 1, 2022
a4854bd
Tokenizer requires to have pad token
thomasw21 Jul 1, 2022
06c29a9
Turns out that test_model.py did not use deepspeed version of models
thomasw21 Jul 1, 2022
aba48b3
Use train_batch instead
thomasw21 Jul 1, 2022
a9d423a
Make it work via DS
thomasw21 Jul 1, 2022
6a95e25
Make it work via DS
thomasw21 Jul 1, 2022
d6e435b
Make it work via DS
thomasw21 Jul 1, 2022
ca8c04a
Make it work via DS
thomasw21 Jul 1, 2022
f3231db
Make it work via DS
thomasw21 Jul 1, 2022
987e6b4
Make it work via DS
thomasw21 Jul 1, 2022
0b27fb6
Make it work via DS
thomasw21 Jul 1, 2022
1ba5d4a
Woops
thomasw21 Jul 1, 2022
cbab16c
Make it work via DS
thomasw21 Jul 1, 2022
4defbb2
Make it work via DS
thomasw21 Jul 1, 2022
412939c
Make it work via DS
thomasw21 Jul 1, 2022
17a6cc0
Maybe
thomasw21 Jul 1, 2022
cb90679
Make it work via DS
thomasw21 Jul 1, 2022
bd4a3f0
Woops
thomasw21 Jul 1, 2022
6604035
Try having very strict mask
thomasw21 Jul 1, 2022
d98e39a
Try updating the kernel
thomasw21 Jul 1, 2022
8495083
Try updating the kernel
thomasw21 Jul 1, 2022
ef5d4d4
Try updating the kernel
thomasw21 Jul 1, 2022
69912b3
Try updating the kernel
thomasw21 Jul 1, 2022
866fc56
Try updating the kernel
thomasw21 Jul 1, 2022
8e9701b
Try updating the kernel
thomasw21 Jul 1, 2022
15d95fa
Inverse causal masking
thomasw21 Jul 1, 2022
fe4f806
Check that the padding are ignored
thomasw21 Jul 1, 2022
cc2aff5
Fix test
thomasw21 Jul 1, 2022
93cde87
Probably should be in this order:
thomasw21 Jul 1, 2022
f6d717b
Revert "Probably should be in this order:"
thomasw21 Jul 1, 2022
910f93b
Add a test checking that ScaledMaskedSoftmax custom kernel does what …
thomasw21 Jul 1, 2022
75f99ef
Head specific mask is not implemented
thomasw21 Jul 1, 2022
c34f107
Test something out
thomasw21 Jul 2, 2022
ed6131a
Test something out
thomasw21 Jul 2, 2022
3a846a0
Test something out
thomasw21 Jul 2, 2022
5746641
Test something out
thomasw21 Jul 2, 2022
292620c
Test something out
thomasw21 Jul 2, 2022
0e1ef5d
Test something out
thomasw21 Jul 2, 2022
964a275
Test something out
thomasw21 Jul 2, 2022
8b31e9c
Test something out
thomasw21 Jul 2, 2022
723a5b3
Test something out
thomasw21 Jul 2, 2022
65b4ea2
Test something out
thomasw21 Jul 2, 2022
7eaced4
Maybe nothing is wrong
thomasw21 Jul 2, 2022
da9f316
Woops
thomasw21 Jul 2, 2022
8b67bd9
Use bloom instead
thomasw21 Jul 2, 2022
84007bc
Make MTF dataloader an infinite dataloader
thomasw21 Jul 2, 2022
273d420
Work into moving packing logic into a dataset
thomasw21 Jul 2, 2022
688d06e
Woops
thomasw21 Jul 2, 2022
ddc6a61
Woops
thomasw21 Jul 2, 2022
0e34e8d
Woops
thomasw21 Jul 2, 2022
014b8b8
Woops
thomasw21 Jul 2, 2022
c53622a
Woops
thomasw21 Jul 2, 2022
ea221a8
Woops
thomasw21 Jul 2, 2022
3274986
Woops
thomasw21 Jul 2, 2022
9a5bf96
Woops
thomasw21 Jul 2, 2022
d160589
Woops
thomasw21 Jul 2, 2022
c3ab5b9
Woops
thomasw21 Jul 2, 2022
f541076
Woops
thomasw21 Jul 2, 2022
20be5b9
Requires to remember how may epochs
thomasw21 Jul 2, 2022
d9719b6
Find a way to reset states everytime
thomasw21 Jul 2, 2022
4e0c4ca
Find a way to reset states everytime
thomasw21 Jul 2, 2022
48a55b9
Find a way to reset states everytime
thomasw21 Jul 2, 2022
2e469e5
Find a way to reset states everytime
thomasw21 Jul 2, 2022
74e03ec
Find a way to reset states everytime
thomasw21 Jul 2, 2022
f4a4733
Fix bugs
thomasw21 Jul 2, 2022
e1a3767
Cleanup
thomasw21 Jul 2, 2022
efeb55a
Merge remote-tracking branch 'official_repo/main' into thomas/mtf_tra…
thomasw21 Jul 2, 2022
de88ab6
Woops
thomasw21 Jul 2, 2022
d7a6388
Woops
thomasw21 Jul 2, 2022
1c2284f
Woops
thomasw21 Jul 2, 2022
b759a92
Woops
thomasw21 Jul 2, 2022
ef20e57
Woops
thomasw21 Jul 2, 2022
5816adf
Silently skip samples that are too long
thomasw21 Jul 2, 2022
37ad57e
Build the index from scratch everytime
thomasw21 Jul 2, 2022
1572ddc
Prevent empty dataset
thomasw21 Jul 2, 2022
bebb481
Change the condition for empty slice
thomasw21 Jul 2, 2022
5c80699
PR reviews
thomasw21 Jul 3, 2022
985cd02
Revert back changes linked to shutil.copytree
thomasw21 Jul 3, 2022
41e931a
Get test working
thomasw21 Jul 3, 2022
b321a34
Woops
thomasw21 Jul 3, 2022
0450bad
Woops
thomasw21 Jul 3, 2022
de4934f
Fix empty samples
thomasw21 Jul 3, 2022
e3e21f5
Cuda kernel is not strictly equivalent
thomasw21 Jul 3, 2022
16c556c
Update tests/test_model.py
thomasw21 Jul 4, 2022
f2df771
MTF optimize dataloading (#298)
thomasw21 Jul 4, 2022
a45c9cd
Get pretrain on non causal mlm script
thomasw21 Jul 4, 2022
606fdeb
Test
thomasw21 Jul 4, 2022
Maybe nothing is wrong
thomasw21 committed Jul 2, 2022
commit 7eaced4567e1a0941fb4faa9b759e853e589d438
121 changes: 58 additions & 63 deletions megatron/fused_kernels/scaled_masked_softmax.h
@@ -47,9 +47,6 @@ __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

@@ -62,6 +59,7 @@ __device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }


int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
@@ -109,16 +107,16 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
*/
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -127,8 +125,8 @@ __global__ void scaled_softmax_warp_forward(
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
@@ -207,10 +205,10 @@ __global__ void scaled_softmax_warp_forward(
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
}
@@ -220,18 +218,18 @@ __global__ void scaled_softmax_warp_forward(
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Explicit masking
*/
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches)
int pad_batches)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -241,17 +239,14 @@ __global__ void scaled_masked_softmax_warp_forward(

// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}

// int local_seq = blockIdx.x + 1;
// int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
@@ -278,18 +273,18 @@ __global__ void scaled_masked_softmax_warp_forward(
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

if (element_index < batch_element_count) {
int itr_idx = i * element_count * element_count + it * WARP_SIZE;
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -317,7 +312,7 @@ __global__ void scaled_masked_softmax_warp_forward(
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (elements[i][it] <= -std::numeric_limits<acc_t>::infinity()) {
elements[i][it] = 0;
elements[i][it] = 0.0f;
} else {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
}
@@ -336,32 +331,32 @@ __global__ void scaled_masked_softmax_warp_forward(
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
if (sum[i] == 0.) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * element_count + it * WARP_SIZE);
if (sum[i] == 0.0f) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE);
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * element_count + it * WARP_SIZE , out);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
} else {
break;
}
}
}
}
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -370,9 +365,9 @@ __global__ void scaled_masked_softmax_warp_backward(
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
@@ -412,10 +407,10 @@ __global__ void scaled_masked_softmax_warp_backward(
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}

acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -443,7 +438,7 @@ __global__ void scaled_masked_softmax_warp_backward(
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
}
@@ -465,11 +460,11 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
@@ -557,12 +552,12 @@ void dispatch_scaled_softmax_forward(

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
@@ -651,12 +646,12 @@ void dispatch_scaled_masked_softmax_forward(

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
@@ -740,4 +735,4 @@ void dispatch_scaled_masked_softmax_backward(
break;
}
}
}
}
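
The hunks above are largely whitespace cleanup, but they also adjust the per-row load/store offsets in scaled_masked_softmax_warp_forward (the `i * element_count * element_count` vs. `i * element_count + it * WARP_SIZE` indexing) and keep the zero-fill path for fully masked rows (`sum[i] == 0`). As a rough point of reference only (this is not code from the PR; the function name and shapes are illustrative), a minimal PyTorch sketch of the semantics the fused kernel is expected to match:

```python
import torch

def reference_scaled_masked_softmax(scores: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # scores: [b, np, sq, sk] attention scores; mask: same shape or broadcastable,
    # with 1 meaning "masked out", matching the kernel's `temp_mask[element] == 1` branch.
    x = scores * scale
    x = x.masked_fill(mask.bool(), float("-inf"))
    probs = torch.softmax(x, dim=-1)
    # A fully masked row yields NaN from softmax; the kernel's `sum[i] == 0.0f`
    # branch writes zeros (copy_zero_vector) instead, so mirror that convention.
    return torch.where(torch.isnan(probs), torch.zeros_like(probs), probs)
```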
14 changes: 12 additions & 2 deletions megatron/model/fused_softmax.py
@@ -157,7 +157,16 @@ def forward(self, input, mask):
assert input.dim() == 4

if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
result = self.forward_fused_softmax(input, mask)
for batch_id in range(len(mask)):
print("Batch id", batch_id)
print(" inputs", input.shape, input[batch_id, 0])
print(" mask", mask.shape, mask[batch_id, 0])
print(" result", result.shape, result[batch_id, 0])
print(" hello", torch.nonzero(~mask[batch_id, 0])[100:150])
print(" bye", torch.nonzero(result[batch_id, 0])[41:100])
print(" all ones?", torch.sum(result, dim=-1))
return result
else:
return self.forward_torch_softmax(input, mask)

@@ -186,8 +195,9 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0

if self.attn_mask_type == AttnMaskType.causal and mask is None:
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
assert mask is None

# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
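
The hunk above adds temporary debug prints after the fused path (inspecting the mask, the nonzero pattern of the result, and whether each softmax row sums to one) and tightens the causal branch to assert that no explicit mask is passed. A hedged, standalone sketch of that row-sum sanity check (shapes and the helper name are illustrative, not part of the repo):

```python
import torch

def check_rows_sum_to_one(probs: torch.Tensor, mask: torch.Tensor, atol: float = 1e-3) -> None:
    # probs and mask are assumed to share the [b, np, sq, sk] shape here
    # (expand the mask over heads first if it is broadcast).
    fully_masked = mask.bool().all(dim=-1)   # [b, np, sq]: rows with every key masked
    row_sums = probs.sum(dim=-1)             # [b, np, sq]
    ones = torch.ones_like(row_sums[~fully_masked])
    assert torch.allclose(row_sums[~fully_masked], ones, atol=atol), "softmax rows should sum to 1"
    assert torch.all(probs[fully_masked] == 0), "fully masked rows should be zero-filled"
```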