
Commit b16a5bc

feat: Add dedicated Megatron lora mode
1 parent 1905677 commit b16a5bc

3 files changed: +470 additions, -15 deletions

src/art/megatron/backend.py

Lines changed: 10 additions & 1 deletion

@@ -1,3 +1,5 @@
+import os
+
 from mp_actors import move_to_child_process
 
 from ..local.backend import LocalBackend
@@ -17,6 +19,7 @@ def __init__(
 
     async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..dev.get_model_config import get_model_config
+        from ..dev.validate import is_dedicated_mode, validate_dedicated_config
         from .service import MegatronService
 
         if model.name not in self._services:
@@ -25,13 +28,19 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
                 output_dir=get_model_dir(model=model, art_path=self._path),
                 config=model._internal_config,
             )
+            validate_dedicated_config(config)
+            dedicated = is_dedicated_mode(config)
+            if dedicated:
+                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+                    str(gpu_id) for gpu_id in config["trainer_gpu_ids"]
+                )
             self._services[model.name] = MegatronService(
                 model_name=model.name,
                 base_model=model.base_model,
                 config=config,
                 output_dir=get_model_dir(model=model, art_path=self._path),
             )
-            if not self._in_process:
+            if not dedicated and not self._in_process:
                 self._services[model.name] = move_to_child_process(
                     self._services[model.name],
                     process_name="megatron-service",
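Note: the dedicated-mode shape of the resolved config is not shown in this commit. Inferred purely from the keys read above and in service.py (trainer_gpu_ids, inference_gpu_ids, engine_args, rollout_weights_mode), a dedicated-mode config might look like the sketch below; the GPU ids and comments are illustrative assumptions, not part of the diff.

# Inferred sketch of a dedicated-mode config, based only on the keys this
# commit reads. The real schema is defined elsewhere and may differ.
config = {
    "trainer_gpu_ids": [0, 1, 2, 3],    # GPUs pinned to the Megatron trainer
    "inference_gpu_ids": [4, 5, 6, 7],  # GPUs pinned to the vLLM subprocess
    "rollout_weights_mode": "lora",     # the dedicated paths currently assert "lora"
    "engine_args": {"max_loras": 2},    # forwarded to the vLLM engine
}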

src/art/megatron/service.py

Lines changed: 266 additions & 14 deletions

@@ -1,14 +1,16 @@
 import asyncio
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import datetime
 from functools import cached_property
 import json
+import logging
 import os
 from pathlib import Path
 import shlex
 import shutil
 import subprocess
-from typing import Any, AsyncIterator
+import sys
+from typing import Any, AsyncIterator, Literal
 
 from peft.tuners.lora.config import LoraConfig
 from pydantic import BaseModel
@@ -21,6 +23,7 @@
 
 from .. import dev, types
 from ..dev.get_model_config import default_target_modules
+from ..dev.validate import is_dedicated_mode
 from ..local.checkpoints import get_last_checkpoint_dir
 from ..preprocessing.pack import DiskPackedTensors
 from ..preprocessing.tokenize import SFTBatch
@@ -49,6 +52,9 @@ class MegatronTrainingJob(BaseModel):
     )
 
 
+logger = logging.getLogger(__name__)
+
+
 @dataclass
 class MegatronService:
     model_name: str
@@ -60,6 +66,24 @@ class MegatronService:
     _lora_id_counter: int = 1
     _megatron_process: asyncio.subprocess.Process | None = None
     _optimizer_state_path: str | None = None
+    _vllm_process: subprocess.Popen | None = field(default=None, repr=False)  # type: ignore[type-arg]
+    _vllm_log_file: Any = field(default=None, repr=False)
+    _vllm_host: str = "127.0.0.1"
+    _vllm_port: int = 0
+
+    @property
+    def is_dedicated(self) -> bool:
+        return is_dedicated_mode(self.config)
+
+    @property
+    def rollout_weights_mode(self) -> Literal["lora", "merged"]:
+        mode = self.config.get("rollout_weights_mode", "lora")
+        assert mode in {"lora", "merged"}
+        return mode
+
+    @property
+    def _vllm_base_url(self) -> str:
+        return f"http://{self._vllm_host}:{self._vllm_port}"
 
     def _next_lora_id(self) -> int:
         self._lora_id_counter += 1
@@ -171,6 +195,144 @@ def _ensure_lora_adapter_config(
             return
         self._default_lora_adapter_config().save_pretrained(lora_path)
 
+    def _resolve_active_lora_path(self) -> str:
+        lora_path = get_last_checkpoint_dir(self.output_dir)
+        if lora_path is None:
+            lora_path = get_step_checkpoint_dir(self.output_dir, 0)
+            self._latest_step = 0
+        else:
+            self._latest_step = get_step_from_dir(self.output_dir)
+        self._ensure_identity_lora(lora_path)
+        self._ensure_lora_adapter_config(lora_path)
+        return lora_path
+
+    async def _start_vllm_subprocess(
+        self,
+        lora_path: str,
+        port: int,
+        config: dev.OpenAIServerConfig | None,
+    ) -> tuple[str, int]:
+        import atexit
+        import httpx
+
+        inference_gpu_ids = self.config["inference_gpu_ids"]
+        cuda_devices = ",".join(str(gpu_id) for gpu_id in inference_gpu_ids)
+
+        server_args: dict[str, object] = {
+            "return_tokens_as_token_ids": True,
+            "enable_auto_tool_choice": True,
+            "tool_call_parser": "hermes",
+        }
+        if config and "server_args" in config:
+            server_args.update(dict(config["server_args"]))
+        for key in ("port", "host", "lora_modules", "api_key"):
+            server_args.pop(key, None)
+
+        engine_args = dict(self.config.get("engine_args", {}))
+        if config and "engine_args" in config:
+            engine_args.update(dict(config["engine_args"]))
+        engine_args.setdefault("generation_config", "vllm")
+        engine_args["enable_lora"] = True
+        engine_args.setdefault("max_loras", 2)
+        for key in ("model", "served_model_name", "enable_sleep_mode"):
+            engine_args.pop(key, None)
+
+        cmd = [
+            sys.executable,
+            "-m",
+            "art.vllm.dedicated_server",
+            f"--model={self.base_model}",
+            f"--port={port}",
+            f"--host={self._vllm_host}",
+            f"--cuda-visible-devices={cuda_devices}",
+            f"--lora-path={lora_path}",
+            f"--served-model-name={self.model_name}@{self._latest_step}",
+            f"--rollout-weights-mode={self.rollout_weights_mode}",
+            f"--engine-args-json={json.dumps(engine_args)}",
+            f"--server-args-json={json.dumps(server_args)}",
+        ]
+
+        log_dir = os.path.join(self.output_dir, "logs")
+        os.makedirs(log_dir, exist_ok=True)
+        self._vllm_log_file = open(
+            os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1
+        )
+        self._vllm_process = subprocess.Popen(
+            cmd,
+            stdout=self._vllm_log_file,
+            stderr=subprocess.STDOUT,
+            bufsize=1,
+        )
+        self._vllm_port = port
+
+        timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600))
+        elapsed = 0.0
+        async with httpx.AsyncClient() as client:
+            while elapsed < timeout:
+                if self._vllm_process.poll() is not None:
+                    raise RuntimeError(
+                        "vLLM subprocess exited with code "
+                        f"{self._vllm_process.returncode}. "
+                        f"Check logs at {log_dir}/vllm-dedicated.log"
+                    )
+                try:
+                    response = await client.get(
+                        f"{self._vllm_base_url}/v1/models",
+                        timeout=5.0,
+                    )
+                    if response.status_code == 200:
+                        break
+                except (httpx.ConnectError, httpx.ReadTimeout):
+                    pass
+                await asyncio.sleep(1.0)
+                elapsed += 1.0
+            else:
+                self._stop_vllm_subprocess()
+                raise TimeoutError(
+                    f"vLLM subprocess did not become ready within {timeout}s. "
+                    f"Check logs at {log_dir}/vllm-dedicated.log"
+                )
+
+        atexit.register(self.close)
+        logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices)
+        return self._vllm_host, self._vllm_port
+
+    async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
+        import httpx
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self._vllm_base_url}/v1/load_lora_adapter",
+                json={
+                    "lora_name": f"{self.model_name}@{step}",
+                    "lora_path": checkpoint_path,
+                    "load_inplace": True,
+                },
+                timeout=60.0,
+            )
+            response.raise_for_status()
+        self._latest_step = step
+
+    def _stop_vllm_subprocess(self) -> None:
+        if self._vllm_process is not None:
+            self._vllm_process.terminate()
+            try:
+                self._vllm_process.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self._vllm_process.kill()
+                self._vllm_process.wait()
+            self._vllm_process = None
+        if self._vllm_log_file is not None:
+            self._vllm_log_file.close()
+            self._vllm_log_file = None
+
+    def _stop_megatron_process(self) -> None:
+        if self._megatron_process is None:
+            return
+        if self._megatron_process.returncode is None:
+            self._megatron_process.terminate()
+        self._megatron_process = None
+
     async def _add_lora_aliases(
         self, llm: AsyncLLM, step: int, checkpoint_dir: str
     ) -> None:
@@ -186,6 +348,10 @@ async def _add_lora_aliases(
         self._latest_step = step
 
     async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
+        if self.is_dedicated:
+            assert self.rollout_weights_mode == "lora"
+            await self._reload_adapter(checkpoint_dir, step)
+            return
         llm = await self.llm
         await llm.pause_generation()
         await self._add_lora_aliases(llm, step, checkpoint_dir)
@@ -209,29 +375,36 @@ async def _ensure_megatron_running(self) -> None:
         subprocess.run(["pkill", "-9", "megatron-service"], check=False)
         train_script = Path(__file__).parent / "train.py"
         project_root = Path(__file__).resolve().parents[3]
-        num_gpus = torch.cuda.device_count()
-        os.environ["MODEL_IDENTIFIER"] = self.base_model
+        launch_env = os.environ.copy()
+        if self.is_dedicated:
+            trainer_gpu_ids = self.config["trainer_gpu_ids"]
+            num_gpus = len(trainer_gpu_ids)
+            launch_env["CUDA_VISIBLE_DEVICES"] = ",".join(
+                str(gpu_id) for gpu_id in trainer_gpu_ids
+            )
+        else:
+            num_gpus = torch.cuda.device_count()
+        launch_env["MODEL_IDENTIFIER"] = self.base_model
 
         command = (
-            f"{setup_cmd}uv run --project {shlex.quote(str(project_root))} "
-            f"torchrun --nproc_per_node {num_gpus} {shlex.quote(str(train_script))}"
+            f"{setup_cmd}{shlex.quote(sys.executable)} -m torch.distributed.run "
+            f"--nproc_per_node {num_gpus} {shlex.quote(str(train_script))}"
         )
         self._megatron_process = await asyncio.create_subprocess_shell(
            command,
            cwd=str(project_root),
+           env=launch_env,
        )
 
     async def start_openai_server(
         self, config: dev.OpenAIServerConfig | None
     ) -> tuple[str, int]:
-        lora_path = get_last_checkpoint_dir(self.output_dir)
-        if lora_path is None:
-            lora_path = get_step_checkpoint_dir(self.output_dir, 0)
-            self._latest_step = 0
-        else:
-            self._latest_step = get_step_from_dir(self.output_dir)
-        self._ensure_identity_lora(lora_path)
-        self._ensure_lora_adapter_config(lora_path)
+        lora_path = self._resolve_active_lora_path()
+
+        if self.is_dedicated:
+            assert self.rollout_weights_mode == "lora"
+            port = (config or {}).get("server_args", {}).get("port", 8000)
+            return await self._start_vllm_subprocess(lora_path, port, config)
 
         lora_path_for_server = (
             lora_path if self._adapter_has_weights(lora_path) else None
@@ -250,15 +423,94 @@ async def start_openai_server(
         )
 
     async def vllm_engine_is_sleeping(self) -> bool:
+        if self.is_dedicated:
+            return False
         return self._is_sleeping
 
+    async def aclose(self) -> None:
+        self.close()
+
+    def close(self) -> None:
+        self._stop_vllm_subprocess()
+        self._stop_megatron_process()
+
     async def train(
         self,
         disk_packed_tensors: DiskPackedTensors,
         config: types.TrainConfig,
         _config: dev.TrainConfig,
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
+        if self.is_dedicated:
+            assert self.rollout_weights_mode == "lora"
+            await self._ensure_megatron_running()
+
+            lora_path = self._resolve_active_lora_path()
+            self._optimizer_state_path = self._get_optimizer_state_path()
+
+            jobs_dir = "/tmp/megatron_training_jobs"
+            os.makedirs(jobs_dir, exist_ok=True)
+            for job_name in os.listdir(jobs_dir):
+                if job_name.endswith(".json"):
+                    os.remove(os.path.join(jobs_dir, job_name))
+            if _config.get("moe_routing_replay_bundle") is not None:
+                raise RuntimeError(
+                    "moe_routing_replay_bundle is only supported for in-process/runtime APIs; "
+                    "MegatronService subprocess jobs must use moe_routing_replay_path."
+                )
+            job = MegatronTrainingJob(
+                lora_path=lora_path,
+                optimizer_state_path=self._optimizer_state_path,
+                disk_packed_tensors=disk_packed_tensors,
+                config=config,
+                experimental_config=_config,
+                moe_routing_replay_path=_config.get("moe_routing_replay_path"),
+                moe_routing_replay_strict=_config.get(
+                    "moe_routing_replay_strict", True
+                ),
+            )
+            job_path = os.path.join(
+                jobs_dir, f"{datetime.datetime.now().isoformat()}.json"
+            )
+            with open(job_path, "w", encoding="utf-8") as handle:
+                handle.write(job.model_dump_json())
+
+            num_lines = 0
+            while True:
+                await asyncio.sleep(0.1)
+                try:
+                    with open(
+                        "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8"
+                    ) as log_file:
+                        log_file.seek(0)
+                        lines = log_file.readlines()[num_lines:]
+                        for line in lines:
+                            line = line.strip()
+                            if not line:
+                                continue
+                            if line == "all done":
+                                self._merge_lora_adapter(lora_path)
+                                os.remove("/tmp/megatron_training_log.jsonl")
+                                break
+                            num_lines += 1
+                            yield json.loads(line)
+                        else:
+                            continue
+                        break
+                except FileNotFoundError:
+                    continue
+
+            next_step = self._latest_step + 1
+            new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step)
+            os.makedirs(new_checkpoint_dir, exist_ok=True)
+            shutil.copy(
+                f"{lora_path}/adapter_model.safetensors",
+                f"{new_checkpoint_dir}/adapter_model.safetensors",
+            )
+            self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path)
+            await self._reload_adapter(new_checkpoint_dir, next_step)
+            return
+
         llm = await self.llm
         await llm.pause_generation()
         await llm.reset_prefix_cache()
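Note: is_dedicated_mode and validate_dedicated_config are imported from ..dev.validate, whose diff is not shown in this view. A minimal hypothetical sketch of what they might enforce, inferred only from the call sites in this commit; the actual committed bodies may differ:

# Hypothetical sketch of the ..dev.validate helpers, inferred from how this
# commit calls them. Not the committed implementation.
def is_dedicated_mode(config: dict) -> bool:
    # Dedicated mode appears to be signaled by explicit GPU partitions.
    return "trainer_gpu_ids" in config and "inference_gpu_ids" in config

def validate_dedicated_config(config: dict) -> None:
    if not is_dedicated_mode(config):
        return
    trainer = set(config["trainer_gpu_ids"])
    inference = set(config["inference_gpu_ids"])
    # The trainer and the vLLM subprocess each get their own
    # CUDA_VISIBLE_DEVICES, so the partitions must be non-empty and disjoint.
    if not trainer or not inference:
        raise ValueError("trainer_gpu_ids and inference_gpu_ids must be non-empty")
    if trainer & inference:
        raise ValueError("trainer_gpu_ids and inference_gpu_ids must not overlap")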
