update fmha (NVIDIA#1344)
yjk21 authored Mar 25, 2022
1 parent a0ed415 commit 3c88451
Showing 26 changed files with 941 additions and 1,304 deletions.
123 changes: 15 additions & 108 deletions apex/contrib/csrc/fmha/fmha_api.cpp
@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;

set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);

@@ -83,16 +83,21 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}

std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool is_nl,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {

auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);

int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
@@ -111,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(false);
}

constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;

const int elts_per_thread = 8 * mmas_m * mmas_n;

auto stream = at::cuda::getCurrentCUDAStream().stream();

TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())

@@ -156,9 +149,8 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());

Fused_multihead_attention_fprop_params params;

set_params(params,
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
@@ -169,22 +161,24 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(),
p_dropout);

// number of times random will be generated per thread, to offset philox counter in the random
launch(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;

if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}

launch(params, is_training, stream);
launch(launch_params, /*configure=*/ false);

return { ctx, s };
}


std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
@@ -270,92 +264,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
return { dqkv, softmax };
}

std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);

constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;

auto stream = at::cuda::getCurrentCUDAStream().stream();

TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())

TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())

TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);

const auto sizes = qkv.sizes();

TORCH_CHECK(sizes[THREE_DIM] == 3);

const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();

auto ctx = torch::empty({ total, num_heads, head_size }, opts);

auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);

if( zero_tensors ) {
ctx.zero_();
s.zero_();
}

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

Fused_multihead_attention_fprop_params params;

set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);

// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;

if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}

launch(params, is_training, num_chunks, stream);

return { ctx, s };
}

std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
@@ -449,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
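For context on the control flow above: mha_fwd now calls the selected runner twice, first with configure=true so the runner can report elts_per_thread (used as the Philox counter offset when acquiring the RNG state), then with configure=false to actually launch. The following is a minimal, self-contained sketch of that two-phase pattern; Params and run_fmha_stub below are simplified stand-ins for illustration, not the actual apex types or kernels.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for Fused_multihead_attention_fprop_params.
struct Params {
    std::uint64_t philox_counter_offset = 0;
};

// Mirrors the shape of the Launch_params struct added to fmha.h.
template <typename Kernel_params>
struct Launch_params {
    std::size_t elts_per_thread = 0;
    bool is_training = false;
    bool is_nl = false;
    Kernel_params params;
};

// Stand-in runner: with configure == true it only reports how many random
// values each thread will consume; with configure == false it would launch.
void run_fmha_stub(Launch_params<Params> &lp, bool configure) {
    if (configure) {
        // Matches the value the old code computed inline for seq_len = 512:
        // 8 * (512 / 16 / warps_m) * (512 / 16 / warps_n) with warps_m = 1, warps_n = 4.
        lp.elts_per_thread = 8 * 32 * 8;
        return;
    }
    std::printf("launch: Philox counter offset per thread = %llu\n",
                (unsigned long long)lp.params.philox_counter_offset);
}

int main() {
    Launch_params<Params> lp;
    lp.is_training = true;

    run_fmha_stub(lp, /*configure=*/true);                  // phase 1: query sizes
    lp.params.philox_counter_offset = lp.elts_per_thread;   // reserve RNG values per thread
    run_fmha_stub(lp, /*configure=*/false);                 // phase 2: launch
    return 0;
}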
53 changes: 43 additions & 10 deletions apex/contrib/csrc/fmha/src/fmha.h
@@ -50,7 +50,7 @@ constexpr int D_DIM = 3;

struct Qkv_params {
// The QKV matrices.
void *qkv_ptr;
void * __restrict__ qkv_ptr;

// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
@@ -64,19 +64,19 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
struct Fused_multihead_attention_fprop_params : public Qkv_params {

// The dQKV matrices.
void *dqkv_ptr;
void * __restrict__ dqkv_ptr;

// Temporary for dKV.
void *dkv_ptr;
void * __restrict__ dkv_ptr;

// The O matrix (output).
void *o_ptr;
void * __restrict__ o_ptr;

// The stride between rows of O.
int64_t o_stride_in_bytes;

// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void *s_ptr;
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
int64_t s_stride_in_bytes;

@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
uint32_t scale_bmm1, scale_softmax, scale_bmm2;

// array of length b+1 holding starting offset of each sequence.
int *cu_seqlens;
int * __restrict__ cu_seqlens;

// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_training_,
bool is_nl_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_training(is_training_)
, is_nl(is_nl_) {
}

size_t elts_per_thread;

cudaDeviceProp * props;

cudaStream_t stream;

bool is_training;

Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
bool is_nl;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);

void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
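The other change in this header adds __restrict__ to the raw pointer members of Qkv_params and Fused_multihead_attention_fprop_params. As a reminder of what the qualifier buys, here is an illustrative sketch on ordinary function parameters (not apex code): declaring the pointers non-aliasing lets the compiler vectorize the loop without runtime overlap checks.

#include <cstddef>

// y and x are promised not to alias, so each x[i] load and y[i] store can be
// reordered and vectorized freely.
void axpy(float * __restrict__ y, const float * __restrict__ x, float a, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        y[i] += a * x[i];
    }
}

int main() {
    float x[4] = {1.f, 2.f, 3.f, 4.f};
    float y[4] = {};
    axpy(y, x, 2.f, 4);
    return 0;
}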
3 changes: 0 additions & 3 deletions apex/contrib/csrc/fmha/src/fmha/gemm.h
@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> {
}
};

////////////////////////////////////////////////////////////////////////////////////////////////////


////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Acc, typename A, typename B, int M, int N>
48 changes: 38 additions & 10 deletions apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
@@ -60,7 +60,7 @@ struct Gmem_tile_qkv {

// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
@@ -125,6 +125,11 @@ struct Gmem_tile_qkv {
actual_seqlen -= ROWS;
}

inline __device__ void move(int steps) {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}

// The stride between rows of the QKV matrix.
int64_t params_qkv_stride_in_bytes_;
// The pointer.
@@ -224,6 +229,11 @@ struct Gmem_tile_o {
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
}

inline __device__ void move(const int steps) {
row_ += ROWS * steps;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
}

// The stride between rows of the O matrix.
int64_t params_o_stride_in_bytes_;
// The pointer.
@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd {

// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {

// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index.
size_t bidx = bidb * params.h + bidh;

@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd {
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}

// The pointer in global memory.
char *ptr_;
@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base {
using Type = typename Base::Type;

// Ctor.
template< typename Params >
inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
: Base(ptr, params, tidx) {
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}

// Store to global memory.
@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base {
}
}

// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}

// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
@@ -361,7 +389,7 @@
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.is_valid(mi, ni, 0, 0) ) {
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
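The move(const int steps) overloads added to Gmem_tile_qkv, Gmem_tile_o, and Gmem_tile_mma_sd generalize the single-step move() so a tile can jump several ROWS-sized chunks at once (used by the non-looping kernels). A small host-side sketch of the intended equivalence follows; Tile is a mock with byte offsets instead of device pointers, not the real classes, and it assumes move(steps) should behave like steps consecutive calls to move().

#include <cassert>
#include <cstdint>

struct Tile {
    static constexpr int ROWS = 16;
    std::int64_t stride_in_bytes = 128;
    std::int64_t ptr = 0;      // byte offset standing in for the global pointer
    int actual_seqlen = 512;

    void move() {              // original single-step advance
        ptr += (std::int64_t)ROWS * stride_in_bytes;
        actual_seqlen -= ROWS;
    }
    void move(int steps) {     // strided advance added by this commit
        ptr += (std::int64_t)ROWS * stride_in_bytes * steps;
        actual_seqlen -= ROWS * steps;
    }
};

int main() {
    Tile a, b;
    for (int i = 0; i < 3; ++i) a.move();   // three single steps
    b.move(3);                              // one strided jump
    assert(a.ptr == b.ptr && a.actual_seqlen == b.actual_seqlen);
    return 0;
}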
(Diffs for the remaining 22 changed files are not shown.)
