11import asyncio
2- from dataclasses import dataclass
2+ from dataclasses import dataclass , field
33import datetime
44from functools import cached_property
55import json
6+ import logging
67import os
78from pathlib import Path
89import shlex
910import shutil
1011import subprocess
11- from typing import Any , AsyncIterator
12+ import sys
13+ from typing import Any , AsyncIterator , Literal
1214
1315from peft .tuners .lora .config import LoraConfig
1416from pydantic import BaseModel
2123
2224from .. import dev , types
2325from ..dev .get_model_config import default_target_modules
26+ from ..dev .validate import is_dedicated_mode
2427from ..local .checkpoints import get_last_checkpoint_dir
2528from ..preprocessing .pack import DiskPackedTensors
2629from ..preprocessing .tokenize import SFTBatch
@@ -49,6 +52,9 @@ class MegatronTrainingJob(BaseModel):
4952)
5053
5154
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
56+
57+
5258@dataclass
5359class MegatronService :
5460 model_name : str
@@ -60,6 +66,24 @@ class MegatronService:
6066 _lora_id_counter : int = 1
6167 _megatron_process : asyncio .subprocess .Process | None = None
6268 _optimizer_state_path : str | None = None
69+ _vllm_process : subprocess .Popen | None = field (default = None , repr = False ) # type: ignore[type-arg]
70+ _vllm_log_file : Any = field (default = None , repr = False )
71+ _vllm_host : str = "127.0.0.1"
72+ _vllm_port : int = 0
73+
    @property
    def is_dedicated(self) -> bool:
        """Whether this service's config selects dedicated mode (per is_dedicated_mode)."""
        return is_dedicated_mode(self.config)
77+
78+ @property
79+ def rollout_weights_mode (self ) -> Literal ["lora" , "merged" ]:
80+ mode = self .config .get ("rollout_weights_mode" , "lora" )
81+ assert mode in {"lora" , "merged" }
82+ return mode
83+
84+ @property
85+ def _vllm_base_url (self ) -> str :
86+ return f"http://{ self ._vllm_host } :{ self ._vllm_port } "
6387
6488 def _next_lora_id (self ) -> int :
6589 self ._lora_id_counter += 1
@@ -171,6 +195,144 @@ def _ensure_lora_adapter_config(
171195 return
172196 self ._default_lora_adapter_config ().save_pretrained (lora_path )
173197
198+ def _resolve_active_lora_path (self ) -> str :
199+ lora_path = get_last_checkpoint_dir (self .output_dir )
200+ if lora_path is None :
201+ lora_path = get_step_checkpoint_dir (self .output_dir , 0 )
202+ self ._latest_step = 0
203+ else :
204+ self ._latest_step = get_step_from_dir (self .output_dir )
205+ self ._ensure_identity_lora (lora_path )
206+ self ._ensure_lora_adapter_config (lora_path )
207+ return lora_path
208+
209+ async def _start_vllm_subprocess (
210+ self ,
211+ lora_path : str ,
212+ port : int ,
213+ config : dev .OpenAIServerConfig | None ,
214+ ) -> tuple [str , int ]:
215+ import atexit
216+ import httpx
217+
218+ inference_gpu_ids = self .config ["inference_gpu_ids" ]
219+ cuda_devices = "," .join (str (gpu_id ) for gpu_id in inference_gpu_ids )
220+
221+ server_args : dict [str , object ] = {
222+ "return_tokens_as_token_ids" : True ,
223+ "enable_auto_tool_choice" : True ,
224+ "tool_call_parser" : "hermes" ,
225+ }
226+ if config and "server_args" in config :
227+ server_args .update (dict (config ["server_args" ]))
228+ for key in ("port" , "host" , "lora_modules" , "api_key" ):
229+ server_args .pop (key , None )
230+
231+ engine_args = dict (self .config .get ("engine_args" , {}))
232+ if config and "engine_args" in config :
233+ engine_args .update (dict (config ["engine_args" ]))
234+ engine_args .setdefault ("generation_config" , "vllm" )
235+ engine_args ["enable_lora" ] = True
236+ engine_args .setdefault ("max_loras" , 2 )
237+ for key in ("model" , "served_model_name" , "enable_sleep_mode" ):
238+ engine_args .pop (key , None )
239+
240+ cmd = [
241+ sys .executable ,
242+ "-m" ,
243+ "art.vllm.dedicated_server" ,
244+ f"--model={ self .base_model } " ,
245+ f"--port={ port } " ,
246+ f"--host={ self ._vllm_host } " ,
247+ f"--cuda-visible-devices={ cuda_devices } " ,
248+ f"--lora-path={ lora_path } " ,
249+ f"--served-model-name={ self .model_name } @{ self ._latest_step } " ,
250+ f"--rollout-weights-mode={ self .rollout_weights_mode } " ,
251+ f"--engine-args-json={ json .dumps (engine_args )} " ,
252+ f"--server-args-json={ json .dumps (server_args )} " ,
253+ ]
254+
255+ log_dir = os .path .join (self .output_dir , "logs" )
256+ os .makedirs (log_dir , exist_ok = True )
257+ self ._vllm_log_file = open (
258+ os .path .join (log_dir , "vllm-dedicated.log" ), "w" , buffering = 1
259+ )
260+ self ._vllm_process = subprocess .Popen (
261+ cmd ,
262+ stdout = self ._vllm_log_file ,
263+ stderr = subprocess .STDOUT ,
264+ bufsize = 1 ,
265+ )
266+ self ._vllm_port = port
267+
268+ timeout = float (os .environ .get ("ART_DEDICATED_VLLM_TIMEOUT" , 600 ))
269+ elapsed = 0.0
270+ async with httpx .AsyncClient () as client :
271+ while elapsed < timeout :
272+ if self ._vllm_process .poll () is not None :
273+ raise RuntimeError (
274+ "vLLM subprocess exited with code "
275+ f"{ self ._vllm_process .returncode } . "
276+ f"Check logs at { log_dir } /vllm-dedicated.log"
277+ )
278+ try :
279+ response = await client .get (
280+ f"{ self ._vllm_base_url } /v1/models" ,
281+ timeout = 5.0 ,
282+ )
283+ if response .status_code == 200 :
284+ break
285+ except (httpx .ConnectError , httpx .ReadTimeout ):
286+ pass
287+ await asyncio .sleep (1.0 )
288+ elapsed += 1.0
289+ else :
290+ self ._stop_vllm_subprocess ()
291+ raise TimeoutError (
292+ f"vLLM subprocess did not become ready within { timeout } s. "
293+ f"Check logs at { log_dir } /vllm-dedicated.log"
294+ )
295+
296+ atexit .register (self .close )
297+ logger .info ("vLLM subprocess ready on port %d (GPUs: %s)" , port , cuda_devices )
298+ return self ._vllm_host , self ._vllm_port
299+
300+ async def _reload_adapter (self , checkpoint_path : str , step : int ) -> None :
301+ import httpx
302+
303+ async with httpx .AsyncClient () as client :
304+ response = await client .post (
305+ f"{ self ._vllm_base_url } /v1/load_lora_adapter" ,
306+ json = {
307+ "lora_name" : f"{ self .model_name } @{ step } " ,
308+ "lora_path" : checkpoint_path ,
309+ "load_inplace" : True ,
310+ },
311+ timeout = 60.0 ,
312+ )
313+ response .raise_for_status ()
314+ self ._latest_step = step
315+
316+ def _stop_vllm_subprocess (self ) -> None :
317+ if self ._vllm_process is not None :
318+ self ._vllm_process .terminate ()
319+ try :
320+ self ._vllm_process .wait (timeout = 5 )
321+ except subprocess .TimeoutExpired :
322+ self ._vllm_process .kill ()
323+ self ._vllm_process .wait ()
324+ self ._vllm_process = None
325+ if self ._vllm_log_file is not None :
326+ self ._vllm_log_file .close ()
327+ self ._vllm_log_file = None
328+
329+ def _stop_megatron_process (self ) -> None :
330+ if self ._megatron_process is None :
331+ return
332+ if self ._megatron_process .returncode is None :
333+ self ._megatron_process .terminate ()
334+ self ._megatron_process = None
335+
174336 async def _add_lora_aliases (
175337 self , llm : AsyncLLM , step : int , checkpoint_dir : str
176338 ) -> None :
@@ -186,6 +348,10 @@ async def _add_lora_aliases(
186348 self ._latest_step = step
187349
188350 async def register_lora_for_step (self , step : int , checkpoint_dir : str ) -> None :
351+ if self .is_dedicated :
352+ assert self .rollout_weights_mode == "lora"
353+ await self ._reload_adapter (checkpoint_dir , step )
354+ return
189355 llm = await self .llm
190356 await llm .pause_generation ()
191357 await self ._add_lora_aliases (llm , step , checkpoint_dir )
@@ -209,29 +375,36 @@ async def _ensure_megatron_running(self) -> None:
209375 subprocess .run (["pkill" , "-9" , "megatron-service" ], check = False )
210376 train_script = Path (__file__ ).parent / "train.py"
211377 project_root = Path (__file__ ).resolve ().parents [3 ]
212- num_gpus = torch .cuda .device_count ()
213- os .environ ["MODEL_IDENTIFIER" ] = self .base_model
378+ launch_env = os .environ .copy ()
379+ if self .is_dedicated :
380+ trainer_gpu_ids = self .config ["trainer_gpu_ids" ]
381+ num_gpus = len (trainer_gpu_ids )
382+ launch_env ["CUDA_VISIBLE_DEVICES" ] = "," .join (
383+ str (gpu_id ) for gpu_id in trainer_gpu_ids
384+ )
385+ else :
386+ num_gpus = torch .cuda .device_count ()
387+ launch_env ["MODEL_IDENTIFIER" ] = self .base_model
214388
215389 command = (
216- f"{ setup_cmd } uv run --project { shlex .quote (str ( project_root )) } "
217- f"torchrun --nproc_per_node { num_gpus } { shlex .quote (str (train_script ))} "
390+ f"{ setup_cmd } { shlex .quote (sys . executable ) } -m torch.distributed.run "
391+ f"--nproc_per_node { num_gpus } { shlex .quote (str (train_script ))} "
218392 )
219393 self ._megatron_process = await asyncio .create_subprocess_shell (
220394 command ,
221395 cwd = str (project_root ),
396+ env = launch_env ,
222397 )
223398
224399 async def start_openai_server (
225400 self , config : dev .OpenAIServerConfig | None
226401 ) -> tuple [str , int ]:
227- lora_path = get_last_checkpoint_dir (self .output_dir )
228- if lora_path is None :
229- lora_path = get_step_checkpoint_dir (self .output_dir , 0 )
230- self ._latest_step = 0
231- else :
232- self ._latest_step = get_step_from_dir (self .output_dir )
233- self ._ensure_identity_lora (lora_path )
234- self ._ensure_lora_adapter_config (lora_path )
402+ lora_path = self ._resolve_active_lora_path ()
403+
404+ if self .is_dedicated :
405+ assert self .rollout_weights_mode == "lora"
406+ port = (config or {}).get ("server_args" , {}).get ("port" , 8000 )
407+ return await self ._start_vllm_subprocess (lora_path , port , config )
235408
236409 lora_path_for_server = (
237410 lora_path if self ._adapter_has_weights (lora_path ) else None
@@ -250,15 +423,94 @@ async def start_openai_server(
250423 )
251424
252425 async def vllm_engine_is_sleeping (self ) -> bool :
426+ if self .is_dedicated :
427+ return False
253428 return self ._is_sleeping
254429
    async def aclose(self) -> None:
        """Async-compatible shutdown hook; delegates to the synchronous close()."""
        self.close()
432+
    def close(self) -> None:
        """Shut down child processes: the vLLM server first, then the Megatron trainer."""
        self._stop_vllm_subprocess()
        self._stop_megatron_process()
436+
255437 async def train (
256438 self ,
257439 disk_packed_tensors : DiskPackedTensors ,
258440 config : types .TrainConfig ,
259441 _config : dev .TrainConfig ,
260442 verbose : bool = False ,
261443 ) -> AsyncIterator [dict [str , float ]]:
444+ if self .is_dedicated :
445+ assert self .rollout_weights_mode == "lora"
446+ await self ._ensure_megatron_running ()
447+
448+ lora_path = self ._resolve_active_lora_path ()
449+ self ._optimizer_state_path = self ._get_optimizer_state_path ()
450+
451+ jobs_dir = "/tmp/megatron_training_jobs"
452+ os .makedirs (jobs_dir , exist_ok = True )
453+ for job_name in os .listdir (jobs_dir ):
454+ if job_name .endswith (".json" ):
455+ os .remove (os .path .join (jobs_dir , job_name ))
456+ if _config .get ("moe_routing_replay_bundle" ) is not None :
457+ raise RuntimeError (
458+ "moe_routing_replay_bundle is only supported for in-process/runtime APIs; "
459+ "MegatronService subprocess jobs must use moe_routing_replay_path."
460+ )
461+ job = MegatronTrainingJob (
462+ lora_path = lora_path ,
463+ optimizer_state_path = self ._optimizer_state_path ,
464+ disk_packed_tensors = disk_packed_tensors ,
465+ config = config ,
466+ experimental_config = _config ,
467+ moe_routing_replay_path = _config .get ("moe_routing_replay_path" ),
468+ moe_routing_replay_strict = _config .get (
469+ "moe_routing_replay_strict" , True
470+ ),
471+ )
472+ job_path = os .path .join (
473+ jobs_dir , f"{ datetime .datetime .now ().isoformat ()} .json"
474+ )
475+ with open (job_path , "w" , encoding = "utf-8" ) as handle :
476+ handle .write (job .model_dump_json ())
477+
478+ num_lines = 0
479+ while True :
480+ await asyncio .sleep (0.1 )
481+ try :
482+ with open (
483+ "/tmp/megatron_training_log.jsonl" , "a+" , encoding = "utf-8"
484+ ) as log_file :
485+ log_file .seek (0 )
486+ lines = log_file .readlines ()[num_lines :]
487+ for line in lines :
488+ line = line .strip ()
489+ if not line :
490+ continue
491+ if line == "all done" :
492+ self ._merge_lora_adapter (lora_path )
493+ os .remove ("/tmp/megatron_training_log.jsonl" )
494+ break
495+ num_lines += 1
496+ yield json .loads (line )
497+ else :
498+ continue
499+ break
500+ except FileNotFoundError :
501+ continue
502+
503+ next_step = self ._latest_step + 1
504+ new_checkpoint_dir = get_step_checkpoint_dir (self .output_dir , next_step )
505+ os .makedirs (new_checkpoint_dir , exist_ok = True )
506+ shutil .copy (
507+ f"{ lora_path } /adapter_model.safetensors" ,
508+ f"{ new_checkpoint_dir } /adapter_model.safetensors" ,
509+ )
510+ self ._ensure_lora_adapter_config (new_checkpoint_dir , source_path = lora_path )
511+ await self ._reload_adapter (new_checkpoint_dir , next_step )
512+ return
513+
262514 llm = await self .llm
263515 await llm .pause_generation ()
264516 await llm .reset_prefix_cache ()
0 commit comments