Skip to content

Commit fb26124

Browse files
authored
feat: Add merged vLLM rollout weights (#631)
1 parent eb6a50d commit fb26124

14 files changed

Lines changed: 641 additions & 45 deletions

dev/run_qwen3_5_localbackend_yes_no_maybe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def _format_int_list(values: list[int]) -> str:
5151
parser.add_argument(
5252
"--enable-thinking", action=argparse.BooleanOptionalAction, default=False
5353
)
54+
parser.add_argument(
55+
"--rollout-weights-mode",
56+
choices=("lora", "merged"),
57+
default=None,
58+
)
5459
parser.add_argument("--trainer-gpu-ids", type=int, nargs="+")
5560
parser.add_argument("--inference-gpu-ids", type=int, nargs="+")
5661
args = parser.parse_args()
@@ -98,6 +103,8 @@ def _format_int_list(values: list[int]) -> str:
98103
f"INFERENCE_GPU_IDS={_format_int_list(args.inference_gpu_ids)}",
99104
]
100105
)
106+
if args.rollout_weights_mode is not None:
107+
env.append(f"ROLLOUT_WEIGHTS_MODE={args.rollout_weights_mode}")
101108
env_block = " \\\n ".join(env)
102109

103110
run_script = textwrap.dedent(
@@ -143,6 +150,7 @@ def _format_int_list(values: list[int]) -> str:
143150
print(f" load_in_4bit: {args.load_in_4bit}")
144151
print(f" load_in_16bit: {args.load_in_16bit}")
145152
print(f" enable_thinking: {args.enable_thinking}")
153+
print(f" rollout_weights_mode: {args.rollout_weights_mode}")
146154
print(f" trainer_gpu_ids: {args.trainer_gpu_ids}")
147155
print(f" inference_gpu_ids: {args.inference_gpu_ids}")
148156

dev/yes-no-maybe-metrics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ def build_internal_config() -> art.dev.InternalModelConfig:
223223
result["trainer_gpu_ids"] = trainer_gpu_ids
224224
result["inference_gpu_ids"] = inference_gpu_ids
225225

226+
rollout_weights_mode = os.environ.get("ROLLOUT_WEIGHTS_MODE")
227+
if rollout_weights_mode is not None:
228+
if rollout_weights_mode not in {"lora", "merged"}:
229+
raise ValueError("ROLLOUT_WEIGHTS_MODE must be either 'lora' or 'merged'")
230+
result["rollout_weights_mode"] = rollout_weights_mode
231+
226232
return result
227233

228234

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ backend = [
3838
"pytest>=8.4.1",
3939
"nbmake>=1.5.5",
4040
"gql<4",
41+
"nvidia-cudnn-frontend<1.21 ; sys_platform == 'linux'",
4142
"vllm @ https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl ; sys_platform == 'linux'",
4243
]
4344
megatron = [

src/art/dev/engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from typing_extensions import TypedDict
44

55

6+
class WeightTransferConfig(TypedDict):
    """Typed shape of the ``weight_transfer_config`` engine argument."""

    # The only backend value the type admits is "nccl".
    backend: Literal["nccl"]
8+
9+
610
class EngineArgs(TypedDict, total=False):
711
model: str
812
served_model_name: str | list[str] | None
@@ -124,6 +128,7 @@ class EngineArgs(TypedDict, total=False):
124128
calculate_kv_scales: bool | None
125129

126130
additional_config: dict[str, Any] | None
131+
weight_transfer_config: WeightTransferConfig | None
127132

128133
disable_log_requests: (
129134
bool # Deprecated in vLLM 0.13+, use enable_log_requests instead

src/art/dev/get_model_config.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
from .engine import EngineArgs
22
from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs
3-
from .validate import is_dedicated_mode
3+
from .validate import QWEN3_5_MOE_MODELS, is_dedicated_mode
4+
5+
6+
def default_target_modules(base_model: str) -> list[str]:
    """Return the default LoRA target module names for ``base_model``.

    Models listed in ``QWEN3_5_MOE_MODELS`` additionally target
    ``in_proj_qkv``, ``in_proj_z`` and ``out_proj``; every other model gets
    only the seven shared projection names.

    Args:
        base_model: Model identifier checked against ``QWEN3_5_MOE_MODELS``.

    Returns:
        A newly built list on every call, so callers may mutate it freely.
    """
    # Names common to both variants, split where the extra entries are
    # inserted so the resulting order matches the original literals.
    head = ["q_proj", "k_proj", "v_proj", "o_proj"]
    tail = ["gate_proj", "up_proj", "down_proj"]
    if base_model in QWEN3_5_MOE_MODELS:
        return [*head, "in_proj_qkv", "in_proj_z", "out_proj", *tail]
    return [*head, *tail]
429

530

631
def get_model_config(
@@ -14,6 +39,7 @@ def get_model_config(
1439
config = InternalModelConfig()
1540

1641
dedicated = is_dedicated_mode(config)
42+
rollout_weights_mode = config.get("rollout_weights_mode", "lora")
1743

1844
if dedicated:
1945
enable_sleep_mode = False
@@ -43,15 +69,7 @@ def get_model_config(
4369
lora_alpha=16,
4470
r=8,
4571
random_state=3407,
46-
target_modules=[
47-
"q_proj",
48-
"k_proj",
49-
"v_proj",
50-
"o_proj",
51-
"gate_proj",
52-
"up_proj",
53-
"down_proj",
54-
],
72+
target_modules=default_target_modules(base_model),
5573
use_gradient_checkpointing="unsloth",
5674
)
5775
peft_args.update(config.get("peft_args", {}))
@@ -78,6 +96,7 @@ def get_model_config(
7896
init_args=init_args,
7997
engine_args=engine_args,
8098
peft_args=peft_args,
99+
rollout_weights_mode=rollout_weights_mode,
81100
tinker_args=config.get("tinker_args"),
82101
trainer_args=trainer_args,
83102
)

src/art/dev/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from enum import Enum
2+
from typing import Literal
23

34
from typing_extensions import Required, TypedDict
45

56
from .engine import EngineArgs
67

8+
# How inference weights are applied in vLLM:
#   "lora"   - load LoRA adapters into vLLM directly
#   "merged" - keep training LoRA adapters, but push merged weights into vLLM
RolloutWeightsMode = Literal["lora", "merged"]
9+
710

811
# Vendored from transformers.training_args.OptimizerNames
912
class OptimizerNames(str, Enum):
@@ -120,6 +123,10 @@ class InternalModelConfig(TypedDict, total=False):
120123
inference run on separate GPUs.
121124
inference_gpu_ids: GPU IDs for vLLM inference (e.g., [1]). When set
122125
with trainer_gpu_ids, enables dedicated mode.
126+
rollout_weights_mode: How inference weights are applied in vLLM.
127+
- "lora": load LoRA adapters into vLLM directly
128+
- "merged": keep training LoRA adapters, but push merged weights
129+
into vLLM for inference
123130
"""
124131

125132
init_args: "InitArgs"
@@ -130,6 +137,7 @@ class InternalModelConfig(TypedDict, total=False):
130137
trainer_args: "TrainerArgs"
131138
trainer_gpu_ids: list[int]
132139
inference_gpu_ids: list[int]
140+
rollout_weights_mode: "RolloutWeightsMode"
133141

134142

135143
class TinkerArgs(TypedDict, total=False):

src/art/dev/validate.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
"""Validation functions for model configuration."""
22

3-
from .model import InternalModelConfig
3+
from .model import InternalModelConfig, RolloutWeightsMode
4+
5+
# Model names that this package treats as Qwen3.5-MoE.
QWEN3_5_MOE_MODELS = {
    "Qwen/Qwen3.5-35B-A3B",
    "Qwen/Qwen3.5-397B-A17B",
}
49

510

611
def is_dedicated_mode(config: InternalModelConfig) -> bool:
    """Return True if the config specifies dedicated mode (separate training and inference GPUs).

    Detection is by key presence only: both ``trainer_gpu_ids`` and
    ``inference_gpu_ids`` must be present in ``config``. The values stored
    under those keys are not inspected here.
    """
    return "trainer_gpu_ids" in config and "inference_gpu_ids" in config
914

1015

16+
def _rollout_weights_mode(config: InternalModelConfig) -> RolloutWeightsMode:
17+
mode = config.get("rollout_weights_mode", "lora")
18+
if mode in {"lora", "merged"}:
19+
return mode
20+
raise ValueError("rollout_weights_mode must be either 'lora' or 'merged'")
21+
22+
23+
def _is_qwen3_5_moe_model(config: InternalModelConfig) -> bool:
    """Return True when ``config["engine_args"]["model"]`` is a known Qwen3.5-MoE model.

    A missing ``engine_args`` entry, or one without a ``model`` key, yields
    False.
    """
    # `or {}` also covers an explicit `engine_args=None`, which would
    # otherwise raise AttributeError on the chained `.get`.
    engine_args = config.get("engine_args") or {}
    return engine_args.get("model") in QWEN3_5_MOE_MODELS
26+
27+
1128
def validate_dedicated_config(config: InternalModelConfig) -> None:
1229
"""Validate dedicated mode GPU configuration.
1330
@@ -16,12 +33,19 @@ def validate_dedicated_config(config: InternalModelConfig) -> None:
1633
"""
1734
has_trainer = "trainer_gpu_ids" in config
1835
has_inference = "inference_gpu_ids" in config
36+
rollout_weights_mode = _rollout_weights_mode(config)
1937

2038
if has_trainer != has_inference:
2139
raise ValueError(
2240
"trainer_gpu_ids and inference_gpu_ids must both be set or both unset"
2341
)
2442

43+
if rollout_weights_mode == "merged" and not has_trainer:
44+
raise ValueError(
45+
"rollout_weights_mode='merged' requires dedicated mode "
46+
"(set both trainer_gpu_ids and inference_gpu_ids)"
47+
)
48+
2549
if not has_trainer:
2650
return
2751

@@ -65,3 +89,9 @@ def validate_dedicated_config(config: InternalModelConfig) -> None:
6589
"enable_sleep_mode is incompatible with dedicated mode "
6690
"(dedicated mode runs vLLM on a separate GPU, sleep/wake is not needed)"
6791
)
92+
93+
if _is_qwen3_5_moe_model(config) and rollout_weights_mode == "lora":
94+
raise ValueError(
95+
"Qwen3.5-MoE models require rollout_weights_mode='merged' with the "
96+
"current vLLM version because direct LoRA inference is currently broken"
97+
)

src/art/megatron/service.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vllm.v1.engine.async_llm import AsyncLLM
2020

2121
from .. import dev, types
22+
from ..dev.get_model_config import default_target_modules
2223
from ..local.checkpoints import get_last_checkpoint_dir
2324
from ..preprocessing.pack import DiskPackedTensors
2425
from ..preprocessing.tokenize import SFTBatch
@@ -66,15 +67,7 @@ def _default_lora_adapter_config(self) -> LoraConfig:
6667
return LoraConfig(
6768
r=1,
6869
lora_alpha=32,
69-
target_modules=[
70-
"q_proj",
71-
"k_proj",
72-
"v_proj",
73-
"o_proj",
74-
"gate_proj",
75-
"up_proj",
76-
"down_proj",
77-
],
70+
target_modules=default_target_modules(self.base_model),
7871
bias="none",
7972
)
8073

0 commit comments

Comments
 (0)