@@ -524,13 +524,19 @@ struct ggml_backend_opencl_context {
524524 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
525525 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_split; // N_SPLIT>1 variant
526526 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
527+ // Flash-decoding K-split: per-split partial kernel + merge kernel for f16 KV.
528+ // Compiled alongside kernels_flash_attn_f32_f16_q1 for the same (dk, dv) set.
529+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1_split;
530+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_merge;
527531 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_q1; // Q=f32, KV=q8_0 decode
532+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_q1_split; // Flash-Decoding Pass 1 for q8_0
528533 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0; // Q=f32, KV=q8_0 prefill (baseline)
529534 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q8_0_split; // N_SPLIT>1 variant
530535 std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_wg_size; // wg_size = bm*n_split
531536 std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_nkv_threshold; // use split when n_kv >= this
532537 std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q8_0_split_bm; // per-split BLOCK_M (usually same as f16 bm)
533538 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_q1; // Q=f32, KV=q4_0 decode
539+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_q1_split; // Flash-Decoding Pass 1 for q4_0
534540 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0; // Q=f32, KV=q4_0 prefill (baseline)
535541 std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q4_0_split; // N_SPLIT>1 variant
536542 std::map<std::pair<int, int>, int> kernels_flash_attn_f32_q4_0_split_wg_size;
@@ -3108,6 +3114,15 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31083114 CL_CHECK((kq1 = clCreateKernel(prog, "flash_attn_f32_f16_q1", &err), err));
31093115 backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k;
31103116 backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = kq1;
3117+ // Flash-Decoding: extract split + merge kernels from the same program.
3118+ cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_f16_q1_split", &err);
3119+ if (err == CL_SUCCESS) {
3120+ backend_ctx->kernels_flash_attn_f32_f16_q1_split[{dk, dv}] = k_split;
3121+ }
3122+ cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3123+ if (err == CL_SUCCESS) {
3124+ backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3125+ }
31113126 break;
31123127 }
31133128 case FA_VARIANT_Q8_0: {
@@ -3116,6 +3131,17 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31163131 CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q8_0", &err), err));
31173132 backend_ctx->kernels_flash_attn_f32_q8_0_q1[{dk, dv}] = kq1;
31183133 backend_ctx->kernels_flash_attn_f32_q8_0[{dk, dv}] = k;
3134+ // Flash-Decoding: extract q8_0 split + merge kernels.
3135+ cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_q8_0_q1_split", &err);
3136+ if (err == CL_SUCCESS) {
3137+ backend_ctx->kernels_flash_attn_f32_q8_0_q1_split[{dk, dv}] = k_split;
3138+ }
3139+ if (!backend_ctx->kernels_flash_attn_f32_merge.count({dk, dv})) {
3140+ cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3141+ if (err == CL_SUCCESS) {
3142+ backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3143+ }
3144+ }
31193145 break;
31203146 }
31213147 case FA_VARIANT_Q4_0: {
@@ -3124,6 +3150,17 @@ static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_
31243150 CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q4_0", &err), err));
31253151 backend_ctx->kernels_flash_attn_f32_q4_0_q1[{dk, dv}] = kq1;
31263152 backend_ctx->kernels_flash_attn_f32_q4_0[{dk, dv}] = k;
3153+ // Flash-Decoding: extract q4_0 split + merge kernels.
3154+ cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_q4_0_q1_split", &err);
3155+ if (err == CL_SUCCESS) {
3156+ backend_ctx->kernels_flash_attn_f32_q4_0_q1_split[{dk, dv}] = k_split;
3157+ }
3158+ if (!backend_ctx->kernels_flash_attn_f32_merge.count({dk, dv})) {
3159+ cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err);
3160+ if (err == CL_SUCCESS) {
3161+ backend_ctx->kernels_flash_attn_f32_merge[{dk, dv}] = k_merge;
3162+ }
3163+ }
31273164 break;
31283165 }
31293166 case FA_VARIANT_F32_F16_SPLIT: {
@@ -10623,17 +10660,38 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
1062310660 cl_ulong mask_pad_nb2 = 0;
1062410661 cl_ulong mask_pad_nb3 = 0;
1062510662
10663+ // Early FD eligibility probe. Used only to gate the non-FD prefill prep
10664+ // kernels (KV pad, blk-mask classification) — the real FD dispatch still
10665+ // happens below with its own guards. Keep the predicates here in sync with
10666+ // the `if (use_fd)` block further down.
10667+ const int fd_is_causal_probe = (mask == NULL && n_q > 1 && n_q == n_kv);
10668+ // Match the gating used by the actual FD dispatch below. Multi-query FD is
10669+ // DK-gated (see FD_MAX_DK_MULTI comment in the dispatch block). FD is also
10670+ // bypassed for DK>128: the single-pass kernel is already compute-bound at
10671+ // that depth, so the partial-buffer + merge overhead regresses decode.
10672+ const int fd_max_n_q_probe = (d_head_q <= 64) ? 8 : 1;
10673+ const bool fd_will_fire =
10674+ (n_q >= 1 && n_q <= fd_max_n_q_probe && n_kv >= 2048 && !fd_is_causal_probe &&
10675+ d_head_q <= 128 &&
10676+ backend_ctx->kernels_flash_attn_f32_merge.count(dk_dv) > 0 &&
10677+ ((is_mixed && backend_ctx->kernels_flash_attn_f32_f16_q1_split.count(dk_dv) > 0) ||
10678+ (is_q8_0 && backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.count(dk_dv) > 0) ||
10679+ (is_q4_0 && backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.count(dk_dv) > 0)));
10680+
1062610681 const int n_q_blocks = n_q > 1 ? (n_q + block_m - 1) / block_m : 0;
1062710682 const int n_kv_blocks = n_kv > 0 ? (n_kv + block_n - 1) / block_n : 0;
10628- const bool use_mixed_prepass = is_mixed && n_q > 1;
10683+ // Non-FD prefill uses KV padding and a per-tile mask classification. When
10684+ // FD will fire these are pure overhead (the FD kernels don't consume them),
10685+ // so gate on `!fd_will_fire`.
10686+ const bool use_mixed_prepass = is_mixed && n_q > 1 && !fd_will_fire;
1062910687 const bool use_kv_pad = use_mixed_prepass && (n_kv % block_n != 0);
1063010688 // blk prepass: classifies each KV tile as fully-masked / mixed / fully-unmasked
1063110689 // based on the attention mask. Drives two optimisations inside the FA kernel:
1063210690 // 0-blocks → skip the tile entirely (~50% of KV reads on causal PP);
1063310691 // 2-blocks → skip per-row mask lookup (~BLOCK_M×BLOCK_N half reads per tile).
1063410692 // Extended to the native q8_0 / q4_0 prefill kernels: they now accept a blk
1063510693 // pointer and consume the classification identically to f32_f16.
10636- const bool use_quant_prepass = (use_native_q8_0 || use_native_q4_0);
10694+    const bool use_quant_prepass = (use_native_q8_0 || use_native_q4_0) && !fd_will_fire;
1063710695 const bool use_blk_mask = (use_mixed_prepass || use_quant_prepass) && mask_buffer != NULL;
1063810696
1063910697 if (use_kv_pad) {
@@ -10732,6 +10790,151 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
1073210790 const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
1073310791 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
1073410792
10793+ // ============================================================================
10794+ // Flash-Decoding (K-split) for decode / short-query path.
10795+ //
10796+    // Single-query (n_q == 1): decode. Eligible for DK ≤ FD_MAX_DK (64/96/128)
10797+    // when n_kv ≥ FD_MIN_N_KV; DK > 128 stays on the single-pass kernel (see below).
10798+ //
10799+ // Multi-query (2 ≤ n_q ≤ FD_MAX_N_Q_MULTI): speculative decoding / parallel-
10800+ // token generation. Each WG owns one (batch, head, query, split) tuple, so
10801+ // K/V reads are NOT shared across queries — total K/V bandwidth scales as
10802+ // O(n_q · n_kv). That beats the prefill kernel only when per-row compute is
10803+ // light relative to launch/merge overhead. Measured on Adreno X1-85:
10804+ // DK=64 (Llama-3.2-1B) : +26 to +115% at pp4, d=2048..16384
10805+ // DK=96 (Phi-3.5-mini) : -21% to neutral
10806+ // DK=128 (Qwen3-{0.6B,4B}): -16% to -21%
10807+ // So multi-query FD is gated on DK ≤ FD_MAX_DK_MULTI. A future rewrite that
10808+ // shares K/V across queries inside a single WG (cross-Q accumulation) should
10809+ // relax this — see e.g. FlashDecoding++ / FlashAttention-2 multi-token
10810+ // decode paths.
10811+ //
10812+ // Splits the KV dimension across N_SPLITS work-groups per (batch, head, q),
10813+ // each writing a (m_c, l_c, O_c) partial to a temp buffer. A tiny merge
10814+ // kernel reduces partials into the final token. Supports f16 / q8_0 / q4_0
10815+ // KV — merge kernel is type-agnostic. !is_causal required (FD loop has no
10816+ // causal bounds; speculative decoding supplies an explicit mask).
10817+ // ============================================================================
10818+ const int FD_MIN_N_KV = 2048;
10819+ const int FD_KV_PER_SP = 2048;
10820+ const int FD_MAX_N_Q_MULTI = 8;
10821+ const int FD_MAX_DK_MULTI = 64;
10822+ const int FD_MAX_DK = 128;
10823+ const int fd_max_n_q = (d_head_q <= FD_MAX_DK_MULTI) ? FD_MAX_N_Q_MULTI : 1;
10824+ // Pick the Pass 1 kernel based on KV type; all three produce identical
10825+ // partial-buffer layout so Pass 2 (merge) is shared. DK>128 is compute-
10826+ // bound in the single-pass kernel; skipping FD there avoids a measured
10827+ // 6-15% decode regression on Qwen3.5-9B (DK=256) at d4096/d8192.
10828+ cl_kernel fd_k_split = NULL;
10829+ if (n_q >= 1 && n_q <= fd_max_n_q && n_kv >= FD_MIN_N_KV && !is_causal &&
10830+ d_head_q <= FD_MAX_DK &&
10831+ backend_ctx->kernels_flash_attn_f32_merge.count(dk_dv) > 0) {
10832+ if (is_mixed &&
10833+ backend_ctx->kernels_flash_attn_f32_f16_q1_split.count(dk_dv) > 0) {
10834+ fd_k_split = backend_ctx->kernels_flash_attn_f32_f16_q1_split.at(dk_dv);
10835+ } else if (is_q8_0 &&
10836+ backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.count(dk_dv) > 0) {
10837+ fd_k_split = backend_ctx->kernels_flash_attn_f32_q8_0_q1_split.at(dk_dv);
10838+ } else if (is_q4_0 &&
10839+ backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.count(dk_dv) > 0) {
10840+ fd_k_split = backend_ctx->kernels_flash_attn_f32_q4_0_q1_split.at(dk_dv);
10841+ }
10842+ }
10843+ const bool use_fd = (fd_k_split != NULL);
10844+
10845+ if (use_fd) {
10846+ // Choose N_SPLITS: roughly n_kv / 2048, clamped to [2, 16].
10847+ int n_splits = (n_kv + FD_KV_PER_SP - 1) / FD_KV_PER_SP;
10848+ if (n_splits < 2) n_splits = 2;
10849+ if (n_splits > 16) n_splits = 16;
10850+ const int kv_per_split = (n_kv + n_splits - 1) / n_splits;
10851+
10852+ // Partial buffer: n_batch × n_head × n_q × n_splits × (2 + DV) floats.
10853+ // Layout [batch][head][query][split][m,l,O] matches the split kernel's
10854+ // record_idx computation.
10855+ const int fa_partial_floats = 2 + d_head_v;
10856+ const size_t partial_size_bytes =
10857+ (size_t) n_batch * n_head * n_q * n_splits * fa_partial_floats * sizeof(float);
10858+
10859+ ggml_cl_flash_attn_temp_buffer temp_partial;
10860+ cl_int err;
10861+ temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE,
10862+ partial_size_bytes, NULL, &err);
10863+ if (err != CL_SUCCESS) {
10864+ CL_CHECK(clFinish(backend_ctx->queue));
10865+ temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE,
10866+ partial_size_bytes, NULL, &err);
10867+ }
10868+ CL_CHECK(err);
10869+
10870+ // --- Pass 1: per-split partials ---
10871+ cl_kernel k_split = fd_k_split;
10872+ int argi = 0;
10873+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &extra_q->data_device));
10874+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_q));
10875+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &k_data_device));
10876+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_k));
10877+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &v_data_device));
10878+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_v));
10879+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &scale));
10880+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_q));
10881+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_kv));
10882+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head));
10883+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb1));
10884+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb2));
10885+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb3));
10886+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb1));
10887+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb2));
10888+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb3));
10889+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb1));
10890+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb2));
10891+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb3));
10892+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &max_bias));
10893+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m0));
10894+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m1));
10895+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_log2_val));
10896+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &logit_softcap));
10897+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_kv));
10898+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &mask_buffer));
10899+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_mask));
10900+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb1));
10901+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb2));
10902+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb3));
10903+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne2));
10904+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne3));
10905+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &temp_partial.data));
10906+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_splits));
10907+ CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &kv_per_split));
10908+
10909+ const size_t fd_wg = 64; // matches Q1_WG_SIZE in the kernel
10910+ size_t fd_lws[3] = { fd_wg, 1, 1 };
10911+ // gid(2) = q_idx * n_splits + split_idx, dispatched as one dim of size
10912+ // n_splits * n_q so the split kernel can decode both indices.
10913+ size_t fd_gws[3] = { fd_wg, (size_t)(n_head * n_batch), (size_t)(n_splits * n_q) };
10914+ backend_ctx->enqueue_ndrange_kernel(k_split, 3, fd_gws, fd_lws, dst);
10915+
10916+ // --- Pass 2: merge ---
10917+ cl_kernel k_merge = backend_ctx->kernels_flash_attn_f32_merge.at(dk_dv);
10918+ argi = 0;
10919+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &temp_partial.data));
10920+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &extra_o->data_device));
10921+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_o));
10922+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_head));
10923+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_splits));
10924+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb1));
10925+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb2));
10926+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb3));
10927+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &sinks_buffer));
10928+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_sinks));
10929+ CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_q));
10930+
10931+ const size_t merge_wg = (size_t) (d_head_v / 4); // one lane per float4
10932+ size_t merge_lws[3] = { merge_wg, 1, 1 };
10933+ size_t merge_gws[3] = { merge_wg, (size_t)(n_head * n_batch), (size_t) n_q };
10934+ backend_ctx->enqueue_ndrange_kernel(k_merge, 3, merge_gws, merge_lws, dst);
10935+ return;
10936+ }
10937+
1073510938 CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
1073610939 CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
1073710940 CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &k_data_device));
0 commit comments