Skip to content

Commit e8d49d4

Browse files
kirklandsign
jpiat authored and committed
Rewrite attention sink from eviction to ring buffer (pytorch#18821)
Differential Revision: D100216687 Pull Request resolved: pytorch#18821
1 parent b0dea90 commit e8d49d4

File tree

10 files changed

+587
-615
lines changed

10 files changed

+587
-615
lines changed

examples/models/llama/BUCK

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,14 @@ fbcode_target(_kind = runtime.python_test,
278278
"source_transformation/test_attention_sink.py",
279279
],
280280
supports_static_listing = False,
281+
preload_deps = [
282+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
283+
"//executorch/extension/llm/custom_ops:custom_ops_aot_py",
284+
],
281285
deps = [
282286
"fbsource//third-party/pypi/parameterized:parameterized",
283287
"//caffe2:torch",
288+
"//executorch/extension/pybindings:portable_lib",
284289
":export_library",
285290
],
286291
)

examples/models/llama/attention.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,14 @@ def forward(
550550

551551
if self.use_kv_cache:
552552
assert input_pos is not None
553-
if self.enable_dynamic_shape:
553+
is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)
554+
555+
if is_ring_buffer:
556+
# Ring buffer models compute their own mask after KV cache
557+
# update; skip start_pos bounds check since start_pos can
558+
# exceed max_context_len for sliding window / attention sink.
559+
attn_mask = None
560+
elif self.enable_dynamic_shape:
554561
start_pos = input_pos[-1].item()
555562
torch._check_is_size(start_pos)
556563
torch._check(start_pos < self.max_context_len)
@@ -569,7 +576,7 @@ def forward(
569576
)
570577
k, v = self.kv_cache.update(input_pos, k, v)
571578

572-
if getattr(self.kv_cache, "is_ring_buffer", False):
579+
if is_ring_buffer:
573580
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
574581
input_pos[0].item(), seqlen
575582
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
base:
2+
metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
3+
4+
model:
5+
use_sdpa_with_kv_cache: True
6+
use_kv_cache: True
7+
dtype_override: fp32
8+
enable_dynamic_shape: True
9+
# Attention Sink: "sink_size,window_size"
10+
# sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
11+
# window_size=124: sliding window size
12+
# KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
13+
use_attention_sink: "4,124"
14+
15+
export:
16+
# max_context_length controls the RoPE frequency table size.
17+
# It must be >= sink_size + window_size (128), but larger values are
18+
# recommended to support generation beyond the sliding window.
19+
# The model default (e.g., 8192 or 131072) is typically used if not specified.
20+
# For testing, we use the model's default by not setting this explicitly.
21+
22+
quantization:
23+
qmode: 8da4w
24+
group_size: 128
25+
embedding_quantize: 4,32
26+
27+
backend:
28+
xnnpack:
29+
enabled: True
30+
extended_ops: True

examples/models/llama/config/test_llm_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
class TestValidation(unittest.TestCase):
2626
def test_invalid_attention_sink(self):
2727
with self.assertRaises(ValueError):
28-
ModelConfig(use_attention_sink="4,2048")
28+
ModelConfig(use_attention_sink="4")
29+
with self.assertRaises(ValueError):
30+
ModelConfig(use_attention_sink="4,2048,1024")
2931

3032
def test_invalid_local_global_attention_format(self):
3133
with self.assertRaises(ValueError):
@@ -79,7 +81,7 @@ def test_valid_llm_config(self):
7981
),
8082
model=ModelConfig(
8183
dtype_override="fp32",
82-
use_attention_sink="4,2048,1024",
84+
use_attention_sink="4,2048",
8385
use_kv_cache=True,
8486
local_global_attention="[16, 32]",
8587
),

examples/models/llama/eval_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
347347
assert llm_config.model.use_attention_sink is not None
348348
assert args.attention_sink_eval_tokens > 0
349349
attention_sink_params = llm_config.model.use_attention_sink.split(",")
350-
assert len(attention_sink_params) == 3
350+
assert len(attention_sink_params) == 2
351351
sink_size = int(attention_sink_params[0])
352352
window_size = int(attention_sink_params[1])
353353

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def build_args_parser() -> argparse.ArgumentParser:
591591
"--use_attention_sink",
592592
default=None,
593593
type=str,
594-
help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
594+
help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>', e.g., '4,2044'.",
595595
)
596596

597597
parser.add_argument(

examples/models/llama/model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,28 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
203203
from .source_transformation.attention_sink import enable_attention_sink
204204

205205
attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
206-
assert len(attention_sink_params) == 3
206+
assert len(attention_sink_params) == 2, (
207+
f"use_attention_sink expects exactly 2 comma-separated values "
208+
f"(sink_size,window_size), got {len(attention_sink_params)}"
209+
)
207210
sink_size = int(attention_sink_params[0])
208211
window_size = int(attention_sink_params[1])
209-
eviction_batch_size = int(attention_sink_params[2])
210212

211-
assert self.llm_config.export.max_context_length == sink_size + window_size
213+
# max_context_length must be >= sink_size + window_size to have enough RoPE frequencies
214+
# A larger max_context_length is allowed (and recommended) to support generation beyond
215+
# the sliding window size.
216+
assert (
217+
self.llm_config.export.max_context_length >= sink_size + window_size
218+
), (
219+
f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
220+
f"sink_size + window_size ({sink_size + window_size})"
221+
)
212222

213223
self.model_ = enable_attention_sink(
214224
module=self.model_,
215225
params=model_args,
216226
sink_size=sink_size,
217227
window_size=window_size,
218-
eviction_batch_size=eviction_batch_size,
219228
)
220229

221230
missing, unexpected = None, None

0 commit comments

Comments (0)