@@ -371,8 +371,41 @@ def replace_kv_cache_with_custom_kv_cache(module):
371371
372372
373373def _replace_kv_cache_with_custom_kv_cache (module ):
374+ # Import here to avoid circular imports
375+ from executorch .examples .models .llama .source_transformation .attention_sink import (
376+ KVCacheWithAttentionSink ,
377+ )
378+
374379 for name , child in module .named_children ():
375- if isinstance (child , KVCache ):
380+ if isinstance (child , KVCacheWithAttentionSink ):
381+ # Replace with custom op variant for performance
382+ setattr (
383+ module ,
384+ name ,
385+ CustomKVCacheWithAttentionSink .from_kv_cache_with_attention_sink (child ),
386+ )
387+ # If parent has SDPACustom, enable explicit mask mode
388+ sdpa = getattr (module , "SDPA" , None )
389+ if sdpa is not None and hasattr (sdpa , "use_attention_mask" ):
390+ sdpa .use_attention_mask = True
391+ elif isinstance (child , RingKVCache ):
392+ # RingKVCache (e.g., from attention sink with sink_size=0) needs
393+ # CustomRingKVCache, not plain CustomKVCache
394+ setattr (
395+ module ,
396+ name ,
397+ CustomRingKVCache (
398+ child .max_batch_size ,
399+ child .window_size ,
400+ child .n_heads ,
401+ child .head_dim ,
402+ dtype = child .k_cache .dtype ,
403+ ),
404+ )
405+ sdpa = getattr (module , "SDPA" , None )
406+ if sdpa is not None and hasattr (sdpa , "use_attention_mask" ):
407+ sdpa .use_attention_mask = True
408+ elif isinstance (child , KVCache ):
376409 cache_shape = child .k_cache .shape
377410 cache_dtype = child .k_cache .dtype
378411 max_batch_size , n_heads , max_context_length , head_dim = cache_shape
@@ -466,6 +499,89 @@ def from_quantized_kv_cache(
466499 )
467500
468501
class CustomKVCacheWithAttentionSink(CustomKVCache):
    """
    CustomKVCache variant for attention-sink models.

    Keeps a block of fixed "sink" slots at the front of the cache plus a
    ring buffer for the sliding window, and routes updates through the
    custom update_cache_with_indices op for performance. Modeled after
    CustomRingKVCache, but position bookkeeping is delegated to
    CachePositionsManagerWithSink.
    """

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        window_size,
        sink_size,
        dtype=torch.float32,
    ):
        # The backing cache holds the sink slots plus twice the window so a
        # wrapped write never collides with live window entries.
        capacity = sink_size + window_size * 2
        super().__init__(max_batch_size, capacity, n_heads, head_dim, dtype)

        # Deferred import: attention_sink imports from this module, so a
        # top-level import would be circular.
        from executorch.examples.models.llama.source_transformation.attention_sink import (
            _create_causal_mask_for_attention_sink,
            CachePositionsManagerWithSink,
        )

        self.window_size = window_size
        self.sink_size = sink_size
        # Downstream code checks this flag to know an explicit mask is needed.
        self.is_ring_buffer = True
        self.cache_positions_manager = CachePositionsManagerWithSink(
            capacity, sink_size
        )
        # Stash the mask builder so create_causal_mask_for_ring_buffer can
        # reach it without re-importing.
        self._create_causal_mask_for_attention_sink = (
            _create_causal_mask_for_attention_sink
        )

    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
        """Build the attention mask for the current cache layout.

        With sink tokens present the sink-aware mask is required; with
        sink_size == 0 the layout degenerates to a plain ring buffer.
        """
        positions = self.cache_positions_manager.cache_positions
        if self.sink_size > 0:
            return self._create_causal_mask_for_attention_sink(
                positions, self.window_size, self.sink_size, start_pos, seq_len
            )
        # NOTE(review): relies on module-level _create_causal_mask_for_ring_buffer
        # being in scope (imported elsewhere in this file) — confirm.
        return _create_causal_mask_for_ring_buffer(
            positions, self.window_size, start_pos, seq_len
        )

    def update(self, input_pos, k_val, v_val):
        """Scatter k_val/v_val into the cache at sink/ring positions.

        input_pos is a tensor whose first element is the absolute start
        position of this update; k_val/v_val are laid out so that
        transpose(1, 2) puts sequence length on dim 1.
        """
        seq_len = k_val.transpose(1, 2).size(1)
        assert seq_len <= self.k_cache.size(
            1
        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"

        # Split the update into tokens that land in sink slots vs. the ring.
        # If the ring portion exceeded ring_size, update_cache_with_indices
        # (a scatter-style op) would see duplicate indices and behave
        # non-deterministically — guard against that here.
        start_pos = input_pos[0].item()
        num_sink_tokens = max(0, min(seq_len, self.sink_size - start_pos))
        num_window_tokens = seq_len - num_sink_tokens
        assert num_window_tokens <= self.cache_positions_manager.ring_size, (
            f"Window tokens ({num_window_tokens}) exceed ring buffer capacity "
            f"({self.cache_positions_manager.ring_size}), which would cause "
            f"non-deterministic behavior with update_cache_with_indices"
        )

        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
            input_pos, seq_len
        ).unsqueeze(0)
        return super().update(input_pos, k_val, v_val, indices)

    @classmethod
    def from_kv_cache_with_attention_sink(cls, kv_cache):
        """Create from an existing KVCacheWithAttentionSink."""
        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
        return cls(
            max_batch_size,
            n_heads,
            head_dim,
            kv_cache.window_size,
            kv_cache.sink_size,
            dtype=kv_cache.k_cache.dtype,
        )
584+
469585class CustomRingKVCache (CustomKVCache ):
470586 def __init__ (
471587 self ,
0 commit comments