3535from ..utils .get_model_step import get_step_from_dir
3636from ..utils .output_dirs import get_step_checkpoint_dir
3737from ..vllm import get_llm , get_worker , openai_server_task , run_on_workers
38- from .train import gc_and_empty_cuda_cache , train
38+ from .train import StopTrainingLoop , gc_and_empty_cuda_cache , train
3939
4040logger = logging .getLogger (__name__ )
4141
@@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol):
5555 def load_lora (self , lora_path : str , load_tensors : bool = True ) -> LoRARequest : ...
5656
5757
58+ class _StopTrainInputs :
59+ """Dedicated sentinel for stopping the background trainer loop."""
60+
61+
62+ _STOP_TRAIN_INPUT = _StopTrainInputs ()
63+ _TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S = 5.0
64+ _TRAIN_TASK_CANCEL_TIMEOUT_S = 1.0
65+
66+
5867def precalculate_new_logprobs (
5968 trainer : "GRPOTrainer" ,
6069 peft_model : "PeftModelForCausalLM" ,
@@ -91,7 +100,7 @@ async def process_train_batch(
91100 packed_tensors : PackedTensors ,
92101 config : types .TrainConfig ,
93102 _config : dev .TrainConfig ,
94- inputs_queue : asyncio .Queue [TrainInputs ],
103+ inputs_queue : asyncio .Queue [TrainInputs | _StopTrainInputs ],
95104 results_queue : asyncio .Queue [dict [str , float ]],
96105 train_task : asyncio .Task [None ],
97106 trainer : "GRPOTrainer" ,
@@ -215,7 +224,7 @@ class UnslothState:
215224 tokenizer : PreTrainedTokenizerBase
216225 peft_model : peft .peft_model .PeftModelForCausalLM
217226 trainer : GRPOTrainer
218- inputs_queue : asyncio .Queue [TrainInputs ]
227+ inputs_queue : asyncio .Queue [TrainInputs | _StopTrainInputs ]
219228 results_queue : asyncio .Queue [dict [str , float ]]
220229 _is_offloaded : bool = False
221230 _pinned_buffers : dict [str , torch .Tensor ] | None = None
@@ -316,6 +325,7 @@ class UnslothService:
316325 _vllm_log_file : Any = field (default = None , repr = False )
317326 _vllm_host : str = "127.0.0.1"
318327 _vllm_port : int = 0
328+ _train_task : asyncio .Task [None ] | None = field (default = None , init = False , repr = False )
319329
320330 @property
321331 def is_dedicated (self ) -> bool :
@@ -326,6 +336,46 @@ def _next_lora_id(self) -> int:
326336 self ._lora_id_counter += 1
327337 return self ._lora_id_counter
328338
339+ def _request_train_task_stop (self ) -> asyncio .Task [None ] | None :
340+ train_task = self ._train_task
341+ if train_task is None :
342+ return None
343+ if train_task .done ():
344+ return train_task
345+
346+ # `_state` is a cached_property. Read from __dict__ directly so shutdown
347+ # does not instantiate the full trainer state solely to stop a task.
348+ state = self .__dict__ .get ("_state" )
349+ if isinstance (state , UnslothState ):
350+ state .inputs_queue .put_nowait (_STOP_TRAIN_INPUT )
351+ return train_task
352+
async def _shutdown_train_task(self) -> None:
    """Stop the background trainer task, escalating graceful -> cancel.

    Always clears ``self._train_task`` on exit, even if the task never
    terminates within the timeouts.
    """
    train_task = self._request_train_task_stop()
    if train_task is None:
        return

    try:
        # Give the trainer loop time to consume the stop sentinel and exit
        # normally before falling back to cancellation.
        await asyncio.wait_for(
            train_task, timeout=_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S
        )
    except asyncio.TimeoutError:
        # Graceful window elapsed: cancel and briefly wait for the task to
        # unwind. Best-effort — if it still does not finish, shutdown
        # proceeds anyway.
        train_task.cancel()
        try:
            await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_CANCEL_TIMEOUT_S)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            pass
    except asyncio.CancelledError:
        # NOTE(review): this swallows CancelledError from the awaited
        # wait_for — presumably so shutdown always completes — but it also
        # suppresses cancellation of this coroutine itself; confirm that is
        # intentional.
        pass
    finally:
        # Drop the reference unconditionally so a later shutdown/close does
        # not try to stop the same task again.
        self._train_task = None
374+
async def aclose(self) -> None:
    """Async teardown: stop the trainer loop first, then run sync close()."""
    await self._shutdown_train_task()
    self.close()
378+
329379 # =========================================================================
330380 # Dedicated mode: vLLM subprocess lifecycle
331381 # =========================================================================
@@ -450,6 +500,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
450500
451501 def close (self ) -> None :
452502 """Terminate vLLM subprocess if running."""
503+ self ._request_train_task_stop ()
453504 if self ._vllm_process is None :
454505 return
455506 self ._vllm_process .terminate ()
@@ -981,17 +1032,19 @@ def _state(self) -> UnslothState:
9811032 trainer .create_optimizer ()
9821033
9831034 # Initialize queues
984- inputs_queue : asyncio .Queue [TrainInputs ] = asyncio .Queue ()
1035+ inputs_queue : asyncio .Queue [TrainInputs | _StopTrainInputs ] = asyncio .Queue ()
9851036 results_queue : asyncio .Queue [dict [str , float ]] = asyncio .Queue ()
9861037
9871038 # Patch trainer _prepare_inputs() to pull from queue
9881039 def _async_prepare_inputs (* _ : Any , ** __ : Any ) -> dict [str , torch .Tensor ]:
989- async def get_inputs () -> TrainInputs :
1040+ async def get_inputs () -> TrainInputs | _StopTrainInputs :
9901041 return await inputs_queue .get ()
9911042
9921043 # Force otherwise synchronous _prepare_inputs() to yield
9931044 # with nested asyncio.run() call
9941045 inputs = asyncio .run (get_inputs ())
1046+ if isinstance (inputs , _StopTrainInputs ):
1047+ raise StopTrainingLoop ()
9951048
9961049 return cast (dict [str , torch .Tensor ], inputs )
9971050
0 commit comments