2323
2424import traceback
2525from collections import defaultdict
26+ from contextlib import contextmanager
2627from dataclasses import dataclass
2728from typing import Any , DefaultDict , Dict , List , Optional , Set , Tuple , Union
2829
@@ -172,6 +173,24 @@ def emit_init(self, op: OpNodeUnion) -> None:
172173 self ._chains .append ([])
173174 self ._chains [self .init_chain_idx ].append (Instruction (op = op ))
174175
@contextmanager
def new_chain(self):
    """Open a fresh instruction chain and temporarily route emit() into it.

    A new empty chain is appended to ``self._chains`` and made current for
    the duration of the ``with`` body; the previously current chain is
    restored on exit, even if the body raises.

    Yields:
        int: index of the newly created chain within ``self._chains``.

    Usage:
        with P.new_chain() as chain_idx:
            P.emit(MulNode(...))  # goes to the new chain
        # P.emit() goes back to the previous chain
    """
    fresh_idx = len(self._chains)
    self._chains.append([])
    saved_chain = self._current_chain
    self._current_chain = fresh_idx
    try:
        yield fresh_idx
    finally:
        # Restore unconditionally so a raising body cannot leave emit()
        # pointed at the temporary chain.
        self._current_chain = saved_chain
175194 def args (self , node : Node ) -> Tuple [Any , ...]:
176195 return self .slot_map (node .args )
177196
@@ -629,9 +648,12 @@ def _verify_build(self):
629648 info .handler in (noop_handler , PatternHandler .deferred_handler )
630649 or n .users == {}
631650 ):
632- assert (
633- self .slot_manager .get_slot (n ) is None
634- ), f"Did not expect node { n } handled by { info .handler } to have a slot"
651+ # Deferred body nodes may or may not have slots — this is fine.
652+ # Pattern handlers absorb nodes into their body and may set
653+ # slots on them (e.g., GatedDeltaRuleHandler sets getitem[0]'s
654+ # slot to the ScanNode output). Dead nodes (no users) also
655+ # skip the slot check.
656+ pass
635657 else :
636658 assert (
637659 self .slot_manager .get_slot (n ) is not None
@@ -962,6 +984,11 @@ def get_named_data_store(self) -> NamedDataStore:
962984 ``ep.constants`` / ``extra_constants`` (which all use unprefixed
963985 keys). The prefix is applied at the exit boundary — the
964986 ``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``.
987+
988+ To reduce peak memory, each constant is deleted from the EP
989+ immediately after its bytes are added to the NamedDataStore.
990+ This avoids holding two full copies of all constants simultaneously
991+ (important for large models where constants can be 20+ GB).
965992 """
966993 named_data_store = NamedDataStore ()
967994
@@ -971,6 +998,17 @@ def get_named_data_store(self) -> NamedDataStore:
971998 key = lambda x : self ._slot_to_final_tid .get (x [1 ], 0 ),
972999 )
9731000
1001+ # Free EP constants not used by the MLX graph to reduce peak memory.
1002+ used = set (self ._constant_name_to_slot .keys ())
1003+ for ispec in self .ep .graph_signature .input_specs :
1004+ if ispec .arg .name in used and ispec .target is not None :
1005+ used .add (ispec .target )
1006+
1007+ for d in (self .ep ._state_dict , self .ep ._constants ):
1008+ for name in list (d .keys ()):
1009+ if name not in used and isinstance (d [name ], torch .Tensor ):
1010+ del d [name ]
1011+
9741012 logger .debug (f"Adding { len (entries )} constants to NamedDataStore..." )
9751013 for canonical_name , _slot in entries :
9761014 tensor = self ._find_constant_tensor (canonical_name )
@@ -983,6 +1021,15 @@ def get_named_data_store(self) -> NamedDataStore:
9831021 data = t ,
9841022 alignment = 16 ,
9851023 )
1024+
1025+ # Free the original tensor from the EP immediately.
1026+ # The contiguous copy is now serialized as bytes in the
1027+ # NamedDataStore — the EP reference is no longer needed.
1028+ # (It would be deleted by lowered_backend_module.py after
1029+ # preprocess() returns anyway.)
1030+ self ._delete_constant_tensor (canonical_name )
1031+ del tensor , t
1032+
9861033 logger .debug ("Done adding constants to NamedDataStore" )
9871034
9881035 return named_data_store
@@ -1011,17 +1058,33 @@ def get_mutable_buffer_names(self) -> List[str]:
10111058
10121059 def _find_constant_tensor (self , name : str ) -> Optional [torch .Tensor ]:
10131060 """Find a constant tensor by name from various sources."""
1014- if name in self .ep .state_dict :
1015- return self .ep .state_dict [name ]
1016- if name in self .ep .constants :
1017- return self .ep .constants [name ]
1061+ result = self ._resolve_constant (name )
1062+ if result is None :
1063+ return None
1064+
1065+ d , k = result
1066+ return d [k ]
1067+
1068+ def _delete_constant_tensor (self , name : str ) -> None :
1069+ """Delete a constant from the EP to free memory during serialization."""
1070+
1071+ result = self ._resolve_constant (name )
1072+ if result :
1073+ d , k = result
1074+ del d [k ]
1075+
1076+ def _resolve_constant (self , name ):
1077+ """Returns (dict, key) or None."""
1078+ if name in self .ep ._state_dict :
1079+ return self .ep ._state_dict , name
1080+ if name in self .ep ._constants :
1081+ return self .ep ._constants , name
10181082 if name in self .extra_constants :
1019- return self .extra_constants [name ]
1020- # Look up by target
1083+ return self .extra_constants , name
10211084 for ispec in self .ep .graph_signature .input_specs :
10221085 if ispec .arg .name == name and ispec .target is not None :
1023- if ispec .target in self .ep .state_dict :
1024- return self .ep .state_dict [ ispec .target ]
1025- if ispec .target in self .ep .constants :
1026- return self .ep .constants [ ispec .target ]
1086+ if ispec .target in self .ep ._state_dict :
1087+ return self .ep ._state_dict , ispec .target
1088+ if ispec .target in self .ep ._constants :
1089+ return self .ep ._constants , ispec .target
10271090 return None
0 commit comments