Commit

fix dropout of flash attention (bigscience-workshop#295)
tohtana authored Nov 21, 2023
1 parent 8415d03 commit b93495a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions megatron/model/transformer.py
@@ -420,17 +420,18 @@ def forward(self, q, k, v):
 
             is_causal = self.causal
             cu_seqlens_k = cu_seqlens_q if get_accelerator().device_name() == 'cuda' else None
+            dropout_p = self.dropout_p
         else:
             # turn off FA causal mask after first inference autoregressive iteration
             # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                         device=q.device) if get_accelerator().device_name() == 'cuda' else None
-            self.dropout_p = 0
+            dropout_p = 0
 
         output = self.flash_attn_func(
             q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            self.dropout_p,
+            dropout_p,
             softmax_scale=self.softmax_scale, causal=is_causal
         ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
             q, k, v, self.dropout_p, self.softmax_scale, is_causal
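
The issue the patch addresses: the inference branch previously wrote self.dropout_p = 0, mutating the module's stored dropout probability. Once a forward pass ran in inference mode, attention dropout stayed disabled for any later training steps. The fix uses a local dropout_p variable instead. A minimal sketch of the same pattern (hypothetical Attn class, not the actual Megatron-DeepSpeed FlashSelfAttention) illustrating why the local variable matters:

import torch.nn as nn

class Attn(nn.Module):
    # Hypothetical, stripped-down module illustrating the commit's fix;
    # not the real class from megatron/model/transformer.py.
    def __init__(self, dropout_p=0.1):
        super().__init__()
        self.dropout_p = dropout_p

    def forward_buggy(self):
        if not self.training:
            # Pre-fix pattern: mutates module state, so dropout stays 0
            # even after switching back to .train().
            self.dropout_p = 0
        return self.dropout_p

    def forward_fixed(self):
        # Post-fix pattern: a local variable leaves self.dropout_p intact.
        dropout_p = self.dropout_p if self.training else 0
        return dropout_p

attn = Attn()
attn.eval(); attn.forward_buggy()
attn.train(); print(attn.forward_buggy())   # 0   -> dropout silently disabled during training

attn2 = Attn()
attn2.eval(); attn2.forward_fixed()
attn2.train(); print(attn2.forward_fixed()) # 0.1 -> training dropout preserved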
