Commit c2496f9
Fixed whitespace.
Costin Eseanu committed Jun 3, 2024
1 parent 25150c1 commit c2496f9
Showing 3 changed files with 26 additions and 21 deletions.
37 changes: 20 additions & 17 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -542,22 +542,23 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                          1);
 
     if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
-    auto prev_key = torch::from_blob(workspace + offset,
-                                     {bsz, heads, all_tokens, k},
-                                     {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
-                                      k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
-                                      k,
-                                      1},
-                                     options);
-
-    auto prev_value =
-        torch::from_blob(workspace + offset + value_offset,
-                         {bsz, heads, all_tokens, k},
-                         {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
-                          k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
-                          k,
-                          1},
-                         options);
+    auto prev_key = torch::from_blob(
+        workspace + offset,
+        {bsz, heads, all_tokens, k},
+        {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k,
+         1},
+        options);
+
+    auto prev_value = torch::from_blob(
+        workspace + offset + value_offset,
+        {bsz, heads, all_tokens, k},
+        {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k,
+         1},
+        options);
 
     return {output, prev_key, prev_value};
 }
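For readers skimming the hunk above: torch::from_blob with an explicit stride list wraps preallocated memory in a tensor view without copying, which is how the KV cache exposes only the first all_tokens rows of each head's slot. A minimal standalone sketch; the buffer size, dimensions, and max_tokens capacity are illustrative stand-ins, not the real workspace layout:

#include <torch/torch.h>
#include <vector>

int main()
{
    // Stand-in for the inference workspace: each of bsz*heads cache slots
    // has room for max_tokens rows of head_dim elements.
    const int64_t bsz = 2, heads = 4, head_dim = 16;
    const int64_t max_tokens = 8;  // slot capacity, like GetMaxTokenLength()
    const int64_t all_tokens = 3;  // rows actually filled so far
    std::vector<float> workspace(bsz * heads * max_tokens * head_dim, 0.f);

    // Strides are in elements. The row stride within a head is head_dim, but
    // the head stride is max_tokens * head_dim, so the view silently skips
    // the unused tail of every slot -- no copy, no reshape.
    auto prev_key = torch::from_blob(
        workspace.data(),
        {bsz, heads, all_tokens, head_dim},
        {heads * max_tokens * head_dim, max_tokens * head_dim, head_dim, 1},
        torch::TensorOptions().dtype(torch::kFloat32));

    // Writes through the view land in the original buffer;
    // from_blob never takes ownership of the memory.
    prev_key.index_put_({0, 0, 0, 0}, 1.0f);
    // workspace[0] == 1.0f now
}

The prev_value view in the diff uses the same overload, just offset further into the workspace.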
@@ -1592,7 +1593,9 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
     auto output = at::from_blob(output_ptr, input.sizes(), options);
     auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options);
     auto intermediate_gemm =
-        at::from_blob(intermediate_ptr, {input.size(0), input.size(1), static_cast<int64_t>(mlp_1_out_neurons)}, options);
+        at::from_blob(intermediate_ptr,
+                      {input.size(0), input.size(1), static_cast<int64_t>(mlp_1_out_neurons)},
+                      options);
 
     auto act_func_type = static_cast<ActivationFuncType>(activation_type);
@@ -56,7 +56,8 @@ __device__ __forceinline__ void B_FromSharedToReg(
         }
     }
 #else
-    assert(("The matrix load functions are only supported on Ampere and newer architectures", false));
+    assert(
+        ("The matrix load functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 #else
@@ -102,7 +103,8 @@ __device__ __forceinline__ void B_FromSharedToReg(
         }
     }
 #else
-    assert(("The matrix load functions are only supported on Ampere and newer architectures", false));
+    assert(
+        ("The matrix load functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 #endif
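An aside on the pattern being rewrapped in this file: assert takes a single expression, so the extra parentheses turn ("message", false) into one comma-operator expression that evaluates to false, while the string literal survives in the stringified condition that the failure output prints. A host-side sketch of the same idiom; the function name here is illustrative:

#include <cassert>

// Comma operator: the literal is evaluated and discarded, the expression
// yields `false`, so the assert always fires when this path is reached.
// Without the inner parentheses the comma would split the macro arguments.
void require_ampere_or_newer(bool is_supported)
{
    if (!is_supported) {
        assert(("The matrix load functions are only supported on Ampere and newer architectures",
                false));
    }
}

int main()
{
    require_ampere_or_newer(false);  // aborts, printing the embedded message
    return 0;
}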
@@ -68,8 +68,8 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal
 }
 
 __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4],
-                                                   uint32_t *__restrict__ read_RPTR_Frag1,
-                                                   uint32_t *__restrict__ read_RPTR_Frag2,
+                                                   uint32_t* __restrict__ read_RPTR_Frag1,
+                                                   uint32_t* __restrict__ read_RPTR_Frag2,
                                                    uint32_t* Scales)
 {
     uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
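The Dequant_32FP6_4Way change is purely declarator spacing (uint32_t *__restrict__ becomes uint32_t* __restrict__), but __restrict__ itself is worth a note: it promises the compiler that the pointers never alias, which lets it keep loads in registers instead of re-reading memory. A tiny sketch in the same style; the function and its body are hypothetical, not DeepSpeed code:

#include <cstddef>
#include <cstdint>

// __restrict__ (a GCC/Clang/NVCC extension) asserts that `dst` and `src`
// do not overlap, so the compiler may reorder and vectorize freely.
// Hypothetical helper, shown only to illustrate the `type* __restrict__`
// declarator style the commit standardizes on.
void scale_packed(std::uint32_t* __restrict__ dst,
                  const std::uint32_t* __restrict__ src,
                  std::size_t n,
                  std::uint32_t factor)
{
    for (std::size_t i = 0; i < n; ++i) { dst[i] = src[i] * factor; }
}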
