hotfix: accelerate plan speed of fa3 template (#690)
Planning for the fa3 template is very slow because we overestimate the size of the workspace that has to be transferred from CPU to GPU; this PR fixes the issue.

cc @nandor @zhyncs
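
For context, a minimal sketch of the pattern the commit message refers to (all names below are invented for illustration and are not FlashInfer's actual plan code): the planner stages scheduling metadata in a page-locked host buffer sized for the worst case, but only the bytes it actually wrote need to be copied to the GPU.

#include <cuda_runtime.h>

// Illustrative only: copy the bytes the plan actually produced, not the
// buffer's upper-bound capacity, so the host-to-device transfer stays small.
cudaError_t CopyPlanMetadata(void* device_int_workspace, const void* pinned_int_workspace,
                             size_t bytes_actually_written, cudaStream_t stream) {
  return cudaMemcpyAsync(device_int_workspace, pinned_int_workspace,
                         bytes_actually_written, cudaMemcpyHostToDevice, stream);
}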
yzh119 authored Dec 23, 2024
1 parent bcf7a3e commit db8f04d
Showing 13 changed files with 323 additions and 145 deletions.
12 changes: 10 additions & 2 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -39,11 +39,19 @@ def get_cu_file_str(
def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
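As a concrete (hypothetical) example of what this generator now emits, formatting the template string with head_dim=128, MaskMode::kCausal, and StandardAttention, and abbreviating the parameter struct as Params, yields four explicit instantiations per combination instead of two; the ragged generator below changes in the same way.

// Example expansion; values chosen for illustration, not copied from a generated file.
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
    128, MaskMode::kCausal, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
    StandardAttention>(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
    128, MaskMode::kCausal, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
    StandardAttention>(Params& params, cudaStream_t stream);
// ...plus the same pair with /*USE_SWA=*/false. Precompiling both schedule modes lets the
// runtime dispatch in batch_prefill_sm90.cu pick either one without extra compilation.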
12 changes: 10 additions & 2 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
@@ -40,11 +40,19 @@ def get_cu_file_str(
def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
11 changes: 0 additions & 11 deletions csrc/aot_extension_utils.h
@@ -30,17 +30,6 @@
#define DISPATCH_mask_mode(expr, const_expr, ...) \
_DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__))

#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()

#define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \
[&]() -> bool { \
if (kv_dtype == q_dtype) { \
70 changes: 39 additions & 31 deletions csrc/batch_prefill_sm90.cu
@@ -29,14 +29,14 @@
namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO,
typename IdType>
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO,
typename IdType>
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);

@@ -47,9 +47,9 @@ using namespace flashinfer;
std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int64_t cuda_stream) {
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -61,12 +61,13 @@ std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

cudaError_t status = PrefillSM90Plan(
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
cudaError_t status =
PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, causal,
enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);

TORCH_CHECK(status == cudaSuccess,
"PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
@@ -151,19 +152,23 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);

bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;

return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
cudaError_t status =
BatchPrefillWithRaggedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
AttentionVariant>(params, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
cudaGetErrorString(status));
return true;
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
cudaGetErrorString(status));
return true;
});
});
});
});
@@ -259,20 +264,23 @@ void BatchPrefillWithPagedKVCacheSM90Run(
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());
bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;

return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
cudaError_t status =
BatchPrefillWithPagedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
AttentionVariant>(params, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
cudaGetErrorString(status));
return true;
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
cudaGetErrorString(status));
return true;
});
});
});
});
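A note on the new total_num_rows plan argument: for an indptr-style layout it is simply the total number of query/output rows across the batch, i.e. the last entry of qo_indptr. A hedged caller-side sketch with invented values:

#include <cstdint>
#include <vector>

// Hypothetical host-side values, for illustration only.
std::vector<int32_t> qo_indptr_host = {0, 17, 45, 96};   // batch_size = 3
unsigned int total_num_rows =
    static_cast<unsigned int>(qo_indptr_host.back());    // 96 rows in total
// Knowing the real row count (and whether all heads share one schedule) lets the planner
// size the metadata it builds and transfers, instead of assuming a worst-case bound.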
6 changes: 3 additions & 3 deletions csrc/flashinfer_ops_sm90.cu
@@ -33,9 +33,9 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int64_t cuda_stream);
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheSM90Run(
unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
11 changes: 11 additions & 0 deletions csrc/pytorch_extension_utils.h
@@ -189,6 +189,17 @@
return __VA_ARGS__(); \
}

#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()

inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name,
const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ",
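A minimal usage sketch of DISPATCH_BOOL as defined above (run_kernel is a placeholder, not a real FlashInfer API): the runtime flag is rebound as a constexpr bool inside the lambda, which is what lets it appear as a template argument in the nested dispatches of batch_prefill_sm90.cu.

#include <cuda_runtime.h>
// Assumes "pytorch_extension_utils.h" is included for DISPATCH_BOOL and TORCH_CHECK.

// Placeholder launcher standing in for e.g. BatchPrefillWithPagedKVCacheDispatched.
template <bool USE_SWA>
cudaError_t run_kernel(cudaStream_t stream) {
  return cudaSuccess;
}

void dispatch_example(bool use_swa, cudaStream_t stream) {
  DISPATCH_BOOL(use_swa, USE_SWA, [&] {
    // Inside the lambda, USE_SWA is a constexpr bool, so it can be a template argument.
    cudaError_t status = run_kernel<USE_SWA>(stream);
    TORCH_CHECK(status == cudaSuccess, cudaGetErrorString(status));
    return true;
  });
}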
(Diffs for the remaining 7 changed files are not shown on this page.)
