Skip to content

Commit bb07e8b

Browse files
authored
Integrate attention sink into ET LLM export and runner (#18860)
Differential Revision: D100216686
Pull Request resolved: #18860
1 parent e8487f3 commit bb07e8b

4 files changed

Lines changed: 192 additions & 13 deletions

File tree

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1777,16 +1777,17 @@ def _get_source_transforms( # noqa
17771777
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
17781778

17791779
if use_sdpa_with_kv_cache:
1780-
transforms.append(replace_kv_cache_with_custom_kv_cache)
1781-
# todo: do this optionally
1782-
# if use attention mask instead of causal attention
1783-
# then create partial function that sets use_attention_mask=True
1780+
# Replace SDPA first, then KV cache. Order matters: the KV cache
1781+
# replacement sets SDPACustom.use_attention_mask=True for ring buffer
1782+
# models (attention sink, sliding window). If SDPA is replaced after,
1783+
# a new SDPACustom(use_attention_mask=False) would overwrite it.
17841784
if use_attention_mask_for_custom_sdpa:
17851785
transforms.append(
17861786
partial(replace_sdpa_with_custom_op, use_attention_mask=True)
17871787
)
17881788
else:
17891789
transforms.append(replace_sdpa_with_custom_op)
1790+
transforms.append(replace_kv_cache_with_custom_kv_cache)
17901791

17911792
if quantize_kv_cache:
17921793
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 117 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -371,8 +371,41 @@ def replace_kv_cache_with_custom_kv_cache(module):
371371

372372

373373
def _replace_kv_cache_with_custom_kv_cache(module):
374+
# Import here to avoid circular imports
375+
from executorch.examples.models.llama.source_transformation.attention_sink import (
376+
KVCacheWithAttentionSink,
377+
)
378+
374379
for name, child in module.named_children():
375-
if isinstance(child, KVCache):
380+
if isinstance(child, KVCacheWithAttentionSink):
381+
# Replace with custom op variant for performance
382+
setattr(
383+
module,
384+
name,
385+
CustomKVCacheWithAttentionSink.from_kv_cache_with_attention_sink(child),
386+
)
387+
# If parent has SDPACustom, enable explicit mask mode
388+
sdpa = getattr(module, "SDPA", None)
389+
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
390+
sdpa.use_attention_mask = True
391+
elif isinstance(child, RingKVCache):
392+
# RingKVCache (e.g., from attention sink with sink_size=0) needs
393+
# CustomRingKVCache, not plain CustomKVCache
394+
setattr(
395+
module,
396+
name,
397+
CustomRingKVCache(
398+
child.max_batch_size,
399+
child.window_size,
400+
child.n_heads,
401+
child.head_dim,
402+
dtype=child.k_cache.dtype,
403+
),
404+
)
405+
sdpa = getattr(module, "SDPA", None)
406+
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
407+
sdpa.use_attention_mask = True
408+
elif isinstance(child, KVCache):
376409
cache_shape = child.k_cache.shape
377410
cache_dtype = child.k_cache.dtype
378411
max_batch_size, n_heads, max_context_length, head_dim = cache_shape
@@ -466,6 +499,89 @@ def from_quantized_kv_cache(
466499
)
467500

468501

502+
class CustomKVCacheWithAttentionSink(CustomKVCache):
503+
"""
504+
CustomKVCache variant for attention sink models.
505+
506+
Uses the custom update_cache_with_indices op for performance while
507+
supporting sink tokens (fixed slots) + ring buffer (sliding window).
508+
Modeled after CustomRingKVCache but with CachePositionsManagerWithSink.
509+
"""
510+
511+
def __init__(
512+
self,
513+
max_batch_size,
514+
n_heads,
515+
head_dim,
516+
window_size,
517+
sink_size,
518+
dtype=torch.float32,
519+
):
520+
# Total cache size: sink slots + ring buffer (2x window for wrap safety)
521+
total_cache_size = sink_size + window_size * 2
522+
super().__init__(max_batch_size, total_cache_size, n_heads, head_dim, dtype)
523+
from executorch.examples.models.llama.source_transformation.attention_sink import (
524+
_create_causal_mask_for_attention_sink,
525+
CachePositionsManagerWithSink,
526+
)
527+
528+
self.cache_positions_manager = CachePositionsManagerWithSink(
529+
total_cache_size, sink_size
530+
)
531+
self.is_ring_buffer = True
532+
self.window_size = window_size
533+
self.sink_size = sink_size
534+
self._create_causal_mask_for_attention_sink = (
535+
_create_causal_mask_for_attention_sink
536+
)
537+
538+
def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
539+
cache_positions = self.cache_positions_manager.cache_positions
540+
if self.sink_size > 0:
541+
return self._create_causal_mask_for_attention_sink(
542+
cache_positions, self.window_size, self.sink_size, start_pos, seq_len
543+
)
544+
else:
545+
return _create_causal_mask_for_ring_buffer(
546+
cache_positions, self.window_size, start_pos, seq_len
547+
)
548+
549+
def update(self, input_pos, k_val, v_val):
550+
seq_len = k_val.transpose(1, 2).size(1)
551+
assert seq_len <= self.k_cache.size(
552+
1
553+
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
554+
# Verify that window tokens don't exceed ring_size, which would cause
555+
# duplicate indices in update_cache_with_indices (scatter-style update).
556+
start_pos = input_pos[0].item()
557+
num_sink_tokens = max(0, min(seq_len, self.sink_size - start_pos))
558+
num_window_tokens = seq_len - num_sink_tokens
559+
assert num_window_tokens <= self.cache_positions_manager.ring_size, (
560+
f"Window tokens ({num_window_tokens}) exceed ring buffer capacity "
561+
f"({self.cache_positions_manager.ring_size}), which would cause "
562+
f"non-deterministic behavior with update_cache_with_indices"
563+
)
564+
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
565+
input_pos, seq_len
566+
)
567+
indices = indices.unsqueeze(0)
568+
569+
return super().update(input_pos, k_val, v_val, indices)
570+
571+
@classmethod
572+
def from_kv_cache_with_attention_sink(cls, kv_cache):
573+
"""Create from an existing KVCacheWithAttentionSink."""
574+
max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
575+
return cls(
576+
max_batch_size,
577+
n_heads,
578+
head_dim,
579+
kv_cache.window_size,
580+
kv_cache.sink_size,
581+
dtype=kv_cache.k_cache.dtype,
582+
)
583+
584+
469585
class CustomRingKVCache(CustomKVCache):
470586
def __init__(
471587
self,

examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -397,3 +397,48 @@ def test_beyond_context_window_basic(self):
397397
self.assertTrue(
398398
torch.isfinite(out).all(), "Output contains non-finite values"
399399
)
400+
401+
def test_beyond_context_window_custom_sdpa(self):
402+
"""Generate tokens beyond context window with custom SDPA + custom KV cache."""
403+
sink_size = 4
404+
window_size = 16
405+
args = self._make_args(max_context_len=128)
406+
model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)
407+
408+
# Verify KV caches were replaced with CustomKVCacheWithAttentionSink
409+
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
410+
CustomKVCacheWithAttentionSink,
411+
)
412+
413+
found_custom_cache = False
414+
for m in model.modules():
415+
if isinstance(m, CustomKVCacheWithAttentionSink):
416+
found_custom_cache = True
417+
break
418+
self.assertTrue(
419+
found_custom_cache, "Expected CustomKVCacheWithAttentionSink in model"
420+
)
421+
422+
# Generate 80 tokens — well beyond KV cache size of 36
423+
outputs = self._run_generation(model, args, num_tokens=80)
424+
425+
self.assertEqual(len(outputs), 77)
426+
for out in outputs:
427+
self.assertTrue(
428+
torch.isfinite(out).all(), "Output contains non-finite values"
429+
)
430+
431+
def test_sink_zero_custom_sdpa(self):
432+
"""Degenerate case: sink_size=0 with custom SDPA (pure ring buffer)."""
433+
sink_size = 0
434+
window_size = 16
435+
args = self._make_args(max_context_len=128)
436+
model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)
437+
438+
outputs = self._run_generation(model, args, num_tokens=60)
439+
440+
self.assertEqual(len(outputs), 57)
441+
for out in outputs:
442+
self.assertTrue(
443+
torch.isfinite(out).all(), "Output contains non-finite values"
444+
)

extension/llm/runner/text_llm_runner.cpp

Lines changed: 25 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,8 @@ Error TextLLMRunner::generate(
109109

110110
stats_->inference_start_ms = time_in_ms();
111111

112+
// Get max_seq_len for single prefill chunk limit
113+
int64_t max_seq_len = metadata_.at(kMaxSeqLen);
112114
int64_t max_context_len = metadata_.at(kMaxContextLen);
113115

114116
uint64_t cur_token = 0;
@@ -137,13 +139,26 @@ Error TextLLMRunner::generate(
137139
InvalidArgument,
138140
"Expected at least 1 prompt token");
139141
ET_CHECK_OR_RETURN_ERROR(
140-
pos_ + num_prompt_tokens < max_context_len,
142+
num_prompt_tokens <= max_seq_len,
141143
InvalidArgument,
142-
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
143-
", Max seq length exceeded - please increase max seq len value in your export script",
144-
pos_,
144+
"num_prompt_tokens %d > max_seq_len %" PRId64
145+
", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
145146
num_prompt_tokens,
146-
max_context_len);
147+
max_seq_len);
148+
// For non-sliding-window models, also check that we won't exceed
149+
// KV cache capacity. Sliding window models (where max_seq_len <
150+
// max_context_len) handle position wrapping internally.
151+
if (max_seq_len >= max_context_len) {
152+
ET_CHECK_OR_RETURN_ERROR(
153+
pos_ + num_prompt_tokens < max_context_len,
154+
InvalidArgument,
155+
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
156+
", Max seq length exceeded - please increase max seq len value in "
157+
"your export script",
158+
pos_,
159+
num_prompt_tokens,
160+
max_context_len);
161+
}
147162

148163
// print prompts
149164
if (config.echo) {
@@ -167,9 +182,11 @@ Error TextLLMRunner::generate(
167182
prefill_next_token_.reset();
168183
}
169184

170-
// Resolve max_new_tokens. pos_ now reflects all occupied positions
171-
// (including prompt tokens just prefilled).
172-
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
185+
// For sliding window models, the ring buffer recycles space — pos_ doesn't
186+
// represent consumed capacity, so pass 0 to get the full budget.
187+
int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
188+
int max_new_tokens =
189+
config.resolve_max_new_tokens(max_context_len, effective_pos);
173190

174191
ET_LOG(
175192
Info,

0 commit comments

Comments
 (0)