Skip to content

Commit 35f78f1

Browse files
committed
opencl: flash-decoding K-split for f16/q8_0/q4_0 KV decode
1 parent a6d1e3d commit 35f78f1

File tree

4 files changed

+1060
-2
lines changed

4 files changed

+1060
-2
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 205 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,13 +524,19 @@ struct ggml_backend_opencl_context {
524524
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
525525
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_split; // N_SPLIT>1 variant
526526
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
527+
// Flash-decoding K-split: per-split partial kernel + merge kernel for f16 KV.
528+
// Compiled alongside kernels_flash_attn_f32_f16_q1 for the same (dk, dv) set.
529+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1_split;
530+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_merge;
527531
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_q1; // Q=f32, KV=q8_0 decode
532+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_q1_split; // Flash-Decoding Pass 1 for q8_0
528533
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0; // Q=f32, KV=q8_0 prefill (baseline)
529534
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_split; // N_SPLIT>1 variant
530535
std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_wg_size; // wg_size = bm*n_split
531536
std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_nkv_threshold; // use split when n_kv >= this
532537
std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_bm; // per-split BLOCK_M (usually same as f16 bm)
533538
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_q1; // Q=f32, KV=q4_0 decode
539+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_q1_split; // Flash-Decoding Pass 1 for q4_0
534540
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0; // Q=f32, KV=q4_0 prefill (baseline)
535541
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_split; // N_SPLIT>1 variant
536542
std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q4_0_split_wg_size;
@@ -3108,6 +3114,15 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31083114
CL_CHECK((kq1 = clCreateKernel(prog, "flash_attn_f32_f16_q1", &err), err));
31093115
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k;
31103116
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = kq1;
3117+
// Flash-Decoding: extract split + merge kernels from the same program.
3118+
cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_f16_q1_split", &err);
3119+
if (err == CL_SUCCESS) {
3120+
backend_ctx->kernels_flash_attn_f32_f16_q1_split[{dk, dv}] = k_split;
3121+
}
3122+
cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3123+
if (err == CL_SUCCESS) {
3124+
backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3125+
}
31113126
break;
31123127
}
31133128
case FA_VARIANT_Q8_0: {
@@ -3116,6 +3131,17 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31163131
CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q8_0", &err), err));
31173132
backend_ctx->kernels_flash_attn_f32_q8_0_q1[{dk, dv}] = kq1;
31183133
backend_ctx->kernels_flash_attn_f32_q8_0[{dk, dv}] = k;
3134+
// Flash-Decoding: extract q8_0 split + merge kernels.
3135+
cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_q8_0_q1_split", &err);
3136+
if (err == CL_SUCCESS) {
3137+
backend_ctx->kernels_flash_attn_f32_q8_0_q1_split[{dk, dv}] = k_split;
3138+
}
3139+
if (!backend_ctx->kernels_flash_attn_f32_merge.count({dk, dv})) {
3140+
cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3141+
if (err == CL_SUCCESS) {
3142+
backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3143+
}
3144+
}
31193145
break;
31203146
}
31213147
case FA_VARIANT_Q4_0: {
@@ -3124,6 +3150,17 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31243150
CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q4_0", &err), err));
31253151
backend_ctx->kernels_flash_attn_f32_q4_0_q1[{dk, dv}] = kq1;
31263152
backend_ctx->kernels_flash_attn_f32_q4_0[{dk, dv}] = k;
3153+
// Flash-Decoding: extract q4_0 split + merge kernels.
3154+
cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_q4_0_q1_split", &err);
3155+
if (err == CL_SUCCESS) {
3156+
backend_ctx->kernels_flash_attn_f32_q4_0_q1_split[{dk, dv}] = k_split;
3157+
}
3158+
if (!backend_ctx->kernels_flash_attn_f32_merge.count({dk, dv})) {
3159+
cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3160+
if (err == CL_SUCCESS) {
3161+
backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3162+
}
3163+
}
31273164
break;
31283165
}
31293166
case FA_VARIANT_F32_F16_SPLIT: {
@@ -10623,17 +10660,38 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
1062310660
cl_ulong mask_pad_nb2 = 0;
1062410661
cl_ulong mask_pad_nb3 = 0;
1062510662

10663+
// Early FD eligibility probe. Used only to gate the non-FD prefill prep
10664+
// kernels (KV pad, blk-mask classification) — the real FD dispatch still
10665+
// happens below with its own guards. Keep the predicates here in sync with
10666+
// the `if (use_fd)` block further down.
10667+
const int fd_is_causal_probe = (mask == NULL && n_q > 1 && n_q == n_kv);
10668+
// Match the gating used by the actual FD dispatch below. Multi-query FD is
10669+
// DK-gated (see FD_MAX_DK_MULTI comment in the dispatch block). FD is also
10670+
// bypassed for DK>128: the single-pass kernel is already compute-bound at
10671+
// that depth, so the partial-buffer + merge overhead regresses decode.
10672+
const int fd_max_n_q_probe = (d_head_q <= 64) ? 8 : 1;
10673+
const bool fd_will_fire =
10674+
(n_q >= 1 && n_q <= fd_max_n_q_probe && n_kv >= 2048 && !fd_is_causal_probe &&
10675+
d_head_q <= 128 &&
10676+
backend_ctx->kernels_flash_attn_f32_merge.count(dk_dv) > 0 &&
10677+
((is_mixed && backend_ctx->kernels_flash_attn_f32_f16_q1_split.count(dk_dv) > 0) ||
10678+
(is_q8_0 && backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.count(dk_dv) > 0) ||
10679+
(is_q4_0 && backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.count(dk_dv) > 0)));
10680+
1062610681
const int n_q_blocks = n_q > 1 ? (n_q + block_m - 1) / block_m : 0;
1062710682
const int n_kv_blocks = n_kv > 0 ? (n_kv + block_n - 1) / block_n : 0;
10628-
const bool use_mixed_prepass = is_mixed && n_q > 1;
10683+
// Non-FD prefill uses KV padding and a per-tile mask classification. When
10684+
// FD will fire these are pure overhead (the FD kernels don't consume them),
10685+
// so gate on `!fd_will_fire`.
10686+
const bool use_mixed_prepass = is_mixed && n_q > 1 && !fd_will_fire;
1062910687
const bool use_kv_pad = use_mixed_prepass && (n_kv % block_n != 0);
1063010688
// blk prepass: classifies each KV tile as fully-masked / mixed / fully-unmasked
1063110689
// based on the attention mask. Drives two optimisations inside the FA kernel:
1063210690
// 0-blocks → skip the tile entirely (~50% of KV reads on causal PP);
1063310691
// 2-blocks → skip per-row mask lookup (~BLOCK_M×BLOCK_N half reads per tile).
1063410692
// Extended to the native q8_0 / q4_0 prefill kernels: they now accept a blk
1063510693
// pointer and consume the classification identically to f32_f16.
10636-
const bool use_quant_prepass = (use_native_q8_0 || use_native_q4_0);
10694+
const bool use_quant_prepass = (use_native_q8_0 || use_native_q4_0) && !fd_will_fire;
1063710695
const bool use_blk_mask = (use_mixed_prepass || use_quant_prepass) && mask_buffer != NULL;
1063810696

1063910697
if (use_kv_pad) {
@@ -10732,6 +10790,151 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
1073210790
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
1073310791
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
1073410792

10793+
// ============================================================================
10794+
// Flash-Decoding (K-split) for decode / short-query path.
10795+
//
10796+
// Single-query (n_q == 1): decode. Eligible for DK∈{64,96,128}; DK>128 is
10797+
// when n_kv ≥ FD_MIN_N_KV.
10798+
//
10799+
// Multi-query (2 ≤ n_q ≤ FD_MAX_N_Q_MULTI): speculative decoding / parallel-
10800+
// token generation. Each WG owns one (batch, head, query, split) tuple, so
10801+
// K/V reads are NOT shared across queries — total K/V bandwidth scales as
10802+
// O(n_q · n_kv). That beats the prefill kernel only when per-row compute is
10803+
// light relative to launch/merge overhead. Measured on Adreno X1-85:
10804+
// DK=64 (Llama-3.2-1B) : +26 to +115% at pp4, d=2048..16384
10805+
// DK=96 (Phi-3.5-mini) : -21% to neutral
10806+
// DK=128 (Qwen3-{0.6B,4B}): -16% to -21%
10807+
// So multi-query FD is gated on DK ≤ FD_MAX_DK_MULTI. A future rewrite that
10808+
// shares K/V across queries inside a single WG (cross-Q accumulation) should
10809+
// relax this — see e.g. FlashDecoding++ / FlashAttention-2 multi-token
10810+
// decode paths.
10811+
//
10812+
// Splits the KV dimension across N_SPLITS work-groups per (batch, head, q),
10813+
// each writing a (m_c, l_c, O_c) partial to a temp buffer. A tiny merge
10814+
// kernel reduces partials into the final token. Supports f16 / q8_0 / q4_0
10815+
// KV — merge kernel is type-agnostic. !is_causal required (FD loop has no
10816+
// causal bounds; speculative decoding supplies an explicit mask).
10817+
// ============================================================================
10818+
const int FD_MIN_N_KV = 2048;
10819+
const int FD_KV_PER_SP = 2048;
10820+
const int FD_MAX_N_Q_MULTI = 8;
10821+
const int FD_MAX_DK_MULTI = 64;
10822+
const int FD_MAX_DK = 128;
10823+
const int fd_max_n_q = (d_head_q <= FD_MAX_DK_MULTI) ? FD_MAX_N_Q_MULTI : 1;
10824+
// Pick the Pass 1 kernel based on KV type; all three produce identical
10825+
// partial-buffer layout so Pass 2 (merge) is shared. DK>128 is compute-
10826+
// bound in the single-pass kernel; skipping FD there avoids a measured
10827+
// 6-15% decode regression on Qwen3.5-9B (DK=256) at d4096/d8192.
10828+
cl_kernel fd_k_split = NULL;
10829+
if (n_q >= 1 && n_q <= fd_max_n_q && n_kv >= FD_MIN_N_KV && !is_causal &&
10830+
d_head_q <= FD_MAX_DK &&
10831+
backend_ctx->kernels_flash_attn_f32_merge.count(dk_dv) > 0) {
10832+
if (is_mixed &&
10833+
backend_ctx->kernels_flash_attn_f32_f16_q1_split.count(dk_dv) > 0) {
10834+
fd_k_split = backend_ctx->kernels_flash_attn_f32_f16_q1_split.at(dk_dv);
10835+
} else if (is_q8_0 &&
10836+
backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.count(dk_dv) > 0) {
10837+
fd_k_split = backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.at(dk_dv);
10838+
} else if (is_q4_0 &&
10839+
backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.count(dk_dv) > 0) {
10840+
fd_k_split = backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.at(dk_dv);
10841+
}
10842+
}
10843+
const bool use_fd = (fd_k_split != NULL);
10844+
10845+
if (use_fd) {
10846+
// Choose N_SPLITS: roughly n_kv / 2048, clamped to [2, 16].
10847+
int n_splits = (n_kv + FD_KV_PER_SP - 1) / FD_KV_PER_SP;
10848+
if (n_splits < 2) n_splits = 2;
10849+
if (n_splits > 16) n_splits = 16;
10850+
const int kv_per_split = (n_kv + n_splits - 1) / n_splits;
10851+
10852+
// Partial buffer: n_batch × n_head × n_q × n_splits × (2 + DV) floats.
10853+
// Layout [batch][head][query][split][m,l,O] matches the split kernel's
10854+
// record_idx computation.
10855+
const int fa_partial_floats = 2 + d_head_v;
10856+
const size_t partial_size_bytes =
10857+
(size_t) n_batch * n_head * n_q * n_splits * fa_partial_floats * sizeof(float);
10858+
10859+
ggml_cl_flash_attn_temp_buffer temp_partial;
10860+
cl_int err;
10861+
temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE,
10862+
partial_size_bytes, NULL, &err);
10863+
if (err != CL_SUCCESS) {
10864+
CL_CHECK(clFinish(backend_ctx->queue));
10865+
temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE,
10866+
partial_size_bytes, NULL, &err);
10867+
}
10868+
CL_CHECK(err);
10869+
10870+
// --- Pass 1: per-split partials ---
10871+
cl_kernel k_split = fd_k_split;
10872+
int argi = 0;
10873+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &extra_q->data_device));
10874+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_q));
10875+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &k_data_device));
10876+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_k));
10877+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &v_data_device));
10878+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_v));
10879+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &scale));
10880+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_q));
10881+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_kv));
10882+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head));
10883+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb1));
10884+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb2));
10885+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb3));
10886+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb1));
10887+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb2));
10888+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb3));
10889+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb1));
10890+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb2));
10891+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb3));
10892+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &max_bias));
10893+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m0));
10894+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m1));
10895+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_log2_val));
10896+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &logit_softcap));
10897+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_kv));
10898+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &mask_buffer));
10899+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_mask));
10900+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb1));
10901+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb2));
10902+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb3));
10903+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne2));
10904+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne3));
10905+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &temp_partial.data));
10906+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_splits));
10907+
CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &kv_per_split));
10908+
10909+
const size_t fd_wg = 64; // matches Q1_WG_SIZE in the kernel
10910+
size_t fd_lws[3] = { fd_wg, 1, 1 };
10911+
// gid(2) = q_idx * n_splits + split_idx, dispatched as one dim of size
10912+
// n_splits * n_q so the split kernel can decode both indices.
10913+
size_t fd_gws[3] = { fd_wg, (size_t)(n_head * n_batch), (size_t)(n_splits * n_q) };
10914+
backend_ctx->enqueue_ndrange_kernel(k_split, 3, fd_gws, fd_lws, dst);
10915+
10916+
// --- Pass 2: merge ---
10917+
cl_kernel k_merge = backend_ctx->kernels_flash_attn_f32_merge.at(dk_dv);
10918+
argi = 0;
10919+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &temp_partial.data));
10920+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &extra_o->data_device));
10921+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_o));
10922+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_head));
10923+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_splits));
10924+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb1));
10925+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb2));
10926+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb3));
10927+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &sinks_buffer));
10928+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_sinks));
10929+
CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_q));
10930+
10931+
const size_t merge_wg = (size_t) (d_head_v / 4); // one lane per float4
10932+
size_t merge_lws[3] = { merge_wg, 1, 1 };
10933+
size_t merge_gws[3] = { merge_wg, (size_t)(n_head * n_batch), (size_t) n_q };
10934+
backend_ctx->enqueue_ndrange_kernel(k_merge, 3, merge_gws, merge_lws, dst);
10935+
return;
10936+
}
10937+
1073510938
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
1073610939
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
1073710940
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &k_data_device));

0 commit comments

Comments
 (0)