Skip to content

Commit b14deb6

Browse files
lilyz-ai and claude committed
feat: make HF weights sync non-blocking with K8s init container
ensure_model_weights_available is now synchronous — it returns the expected checkpoint path immediately and fires a background asyncio task to sync weights from HuggingFace Hub. An init container is injected into the K8s deployment to poll storage until the weights are present before the main container starts. LLMMetadata gains an hf_weights_syncing flag to signal this flow downstream. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3e7a2a4 commit b14deb6

5 files changed

Lines changed: 114 additions & 34 deletions

File tree

model-engine/model_engine_server/domain/entities/llm_entity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ class LLMMetadata:
3131
quantize: Optional[Quantization] = None
3232
checkpoint_path: Optional[str] = None
3333
chat_template_override: Optional[str] = None
34+
hf_weights_syncing: bool = False

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,8 @@ def load_model_weights_sub_commands_s3(
652652
s5cmd = "./s5cmd"
653653

654654
checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
655-
validate_checkpoint_files(checkpoint_files)
655+
if checkpoint_files:
656+
validate_checkpoint_files(checkpoint_files)
656657

657658
# filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
658659
# For models that are not supported by transformers directly, we need to include '*.py' and '*.bin'
@@ -1389,18 +1390,20 @@ async def execute(
13891390
"Multinode endpoints are only supported for VLLM models."
13901391
)
13911392

1392-
# Resolve checkpoint path: auto-download from HF Hub to remote storage if not cached
1393+
# Resolve checkpoint path: fires background sync and returns expected path immediately
13931394
checkpoint_path = request.checkpoint_path
1395+
hf_weights_syncing = False
13941396
if (
13951397
checkpoint_path is None
13961398
and request.source == LLMSource.HUGGING_FACE
13971399
and self.model_weights_manager is not None
13981400
):
13991401
models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
14001402
if models_info and models_info.hf_repo:
1401-
checkpoint_path = await self.model_weights_manager.ensure_model_weights_available(
1402-
hf_repo=models_info.hf_repo
1403+
checkpoint_path = self.model_weights_manager.ensure_model_weights_available(
1404+
models_info.hf_repo
14031405
)
1406+
hf_weights_syncing = True
14041407

14051408
bundle = await self.create_llm_model_bundle_use_case.execute(
14061409
user,
@@ -1447,6 +1450,7 @@ async def execute(
14471450
quantize=request.quantize,
14481451
checkpoint_path=checkpoint_path,
14491452
chat_template_override=request.chat_template_override,
1453+
hf_weights_syncing=hf_weights_syncing,
14501454
)
14511455
)
14521456

model-engine/model_engine_server/domain/use_cases/model_weights_manager.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,35 @@ class ModelWeightsManager:
2525
def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
2626
self.llm_artifact_gateway = llm_artifact_gateway
2727

28-
def _get_remote_path(self, hf_repo: str) -> str:
28+
def get_remote_path(self, hf_repo: str) -> str:
2929
prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
3030
return f"{prefix}/{hf_repo}"
3131

32-
async def ensure_model_weights_available(self, hf_repo: str) -> str:
32+
def ensure_model_weights_available(self, hf_repo: str) -> str:
3333
"""
34-
Ensures model weights for ``hf_repo`` are available at the configured remote path.
34+
Returns the expected remote path for ``hf_repo`` immediately and starts
35+
syncing weights from HuggingFace Hub to that path in the background.
3536
36-
If the weights are already cached (remote path is non-empty), returns immediately.
37-
Otherwise downloads from HuggingFace Hub and uploads to the remote path.
37+
If the weights are already cached the background task exits early.
38+
Callers receive the checkpoint path right away and can proceed with
39+
any following actions (e.g. endpoint creation) without blocking.
3840
3941
Args:
4042
hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.
4143
4244
Returns:
43-
The remote path (s3://, gs://, or https://) where the weights are stored.
45+
The remote path (s3://, gs://, or https://) where the weights will be stored.
4446
"""
45-
remote_path = self._get_remote_path(hf_repo)
47+
remote_path = self.get_remote_path(hf_repo)
48+
asyncio.create_task(self._sync_weights(hf_repo, remote_path))
49+
return remote_path
50+
51+
async def _sync_weights(self, hf_repo: str, remote_path: str) -> None:
52+
"""Downloads weights from HuggingFace Hub and uploads to remote storage if not cached."""
4653
files = self.llm_artifact_gateway.list_files(remote_path)
4754
if files:
4855
logger.info(f"Cache hit: {len(files)} files at {remote_path}")
49-
return remote_path
56+
return
5057

5158
logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
5259
loop = asyncio.get_event_loop()
@@ -70,4 +77,3 @@ async def ensure_model_weights_available(self, hf_repo: str) -> str:
7077
)
7178

7279
logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
73-
return remote_path

model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@
6262
BASE_PATH_IN_ENDPOINT = "/app"
6363

6464
DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"}
65+
66+
# Key under which LLM metadata is stored in model_endpoint_record.metadata
67+
_LLM_METADATA_KEY = "_llm"
68+
69+
# Python script run by the init container to poll storage until HF weights are present.
70+
_HF_WEIGHTS_POLL_SCRIPT = """\
71+
import boto3, os, sys, time
72+
from urllib.parse import urlparse
73+
74+
cp = os.environ["CHECKPOINT_PATH"]
75+
url = urlparse(cp)
76+
bucket = url.netloc
77+
prefix = url.path.lstrip("/")
78+
s3 = boto3.client("s3")
79+
while True:
80+
resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
81+
if resp.get("Contents"):
82+
print(f"Model weights ready at {cp}", flush=True)
83+
sys.exit(0)
84+
print(f"Waiting for model weights at {cp}...", flush=True)
85+
time.sleep(30)
86+
"""
6587
LWS_DEFAULT_ENV_VAR = {
6688
"K8S_OWN_POD_NAME",
6789
"K8S_OWN_NAMESPACE",
@@ -339,6 +361,42 @@ def add_pod_metadata_env_to_container(container: Dict[str, Any]) -> None:
339361
)
340362

341363

364+
def add_hf_weights_init_container(
365+
deployment_template: Dict[str, Any],
366+
checkpoint_path: str,
367+
) -> None:
368+
"""Prepend an init container that polls storage until HF weights are present.
369+
370+
Uses the forwarder image (model-engine gateway image, which has Python and
371+
boto3) so no additional image pull is required. Authentication relies on
372+
the pod's service account (IRSA / workload-identity).
373+
"""
374+
containers = deployment_template["spec"]["template"]["spec"]["containers"]
375+
# Prefer the forwarder container image; fall back to the first container.
376+
forwarder_image = next(
377+
(c["image"] for c in containers if c["name"] in ("http-forwarder", "celery-forwarder")),
378+
containers[0]["image"],
379+
)
380+
381+
init_container: Dict[str, Any] = {
382+
"name": "wait-for-model-weights",
383+
"image": forwarder_image,
384+
"env": [{"name": "CHECKPOINT_PATH", "value": checkpoint_path}],
385+
"command": ["python3", "-c", _HF_WEIGHTS_POLL_SCRIPT],
386+
}
387+
388+
# Reuse the AWS config volume mount if the volume is present in the pod spec
389+
volumes = deployment_template["spec"]["template"]["spec"].get("volumes", [])
390+
if any(v["name"] == "config-volume" for v in volumes):
391+
init_container["volumeMounts"] = [
392+
{"name": "config-volume", "mountPath": "/opt/.aws/config", "subPath": "config"}
393+
]
394+
395+
deployment_template["spec"]["template"]["spec"].setdefault("initContainers", []).append(
396+
init_container
397+
)
398+
399+
342400
def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None:
343401
container_envs = []
344402
container_envs.extend(
@@ -1657,6 +1715,9 @@ async def _create_or_update_resources(
16571715
user_container = get_main_container_from_deployment_template(deployment_template)
16581716
add_datadog_env_to_container(deployment_template, user_container)
16591717
add_pod_metadata_env_to_container(user_container)
1718+
llm_metadata = (model_endpoint_record.metadata or {}).get(_LLM_METADATA_KEY, {})
1719+
if llm_metadata.get("hf_weights_syncing") and llm_metadata.get("checkpoint_path"):
1720+
add_hf_weights_init_container(deployment_template, llm_metadata["checkpoint_path"])
16601721
await self._create_deployment(
16611722
model_endpoint_record=request.build_endpoint_request.model_endpoint_record,
16621723
deployment=deployment_template,

model-engine/tests/unit/domain/test_model_weights_manager.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,14 @@ async def test_cache_hit_skips_download():
4040
gateway = FakeArtifactGateway(existing_files=["model.safetensors"])
4141
manager = ModelWeightsManager(llm_artifact_gateway=gateway)
4242

43-
with patch(
44-
"model_engine_server.domain.use_cases.model_weights_manager.snapshot_download"
45-
) as mock_download:
46-
result = await manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B")
43+
mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
44+
with (
45+
patch(f"{mwm_base}.snapshot_download") as mock_download,
46+
patch(f"{mwm_base}.asyncio.create_task") as mock_create_task,
47+
):
48+
result = manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B")
49+
# Run the background sync task to assert on side-effects
50+
await mock_create_task.call_args[0][0]
4751

4852
mock_download.assert_not_called()
4953
assert len(gateway.uploaded) == 0
@@ -60,8 +64,10 @@ async def test_cache_hit_returns_correct_s3_path(monkeypatch):
6064
gateway = FakeArtifactGateway(existing_files=["file.bin"])
6165
manager = ModelWeightsManager(llm_artifact_gateway=gateway)
6266

63-
with patch("model_engine_server.domain.use_cases.model_weights_manager.snapshot_download"):
64-
result = await manager.ensure_model_weights_available("org/model")
67+
mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
68+
with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task:
69+
result = manager.ensure_model_weights_available("org/model")
70+
await mock_create_task.call_args[0][0]
6571

6672
assert result == "s3://my-bucket/weights/org/model"
6773

@@ -77,10 +83,14 @@ async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypat
7783
gateway = FakeArtifactGateway(existing_files=[])
7884
manager = ModelWeightsManager(llm_artifact_gateway=gateway)
7985

80-
with patch(
81-
"model_engine_server.domain.use_cases.model_weights_manager.snapshot_download"
82-
) as mock_download:
83-
result = await manager.ensure_model_weights_available("org/model")
86+
mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
87+
with (
88+
patch(f"{mwm_base}.snapshot_download") as mock_download,
89+
patch(f"{mwm_base}.asyncio.create_task") as mock_create_task,
90+
):
91+
result = manager.ensure_model_weights_available("org/model")
92+
# Run the background sync task so we can assert on its side-effects
93+
await mock_create_task.call_args[0][0]
8494

8595
mock_download.assert_called_once()
8696
call_kwargs = mock_download.call_args
@@ -93,8 +103,7 @@ async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypat
93103
assert result == "s3://my-bucket/weights/org/model"
94104

95105

96-
@pytest.mark.asyncio
97-
async def test_s3_path_construction(monkeypatch):
106+
def test_s3_path_construction(monkeypatch):
98107
"""Remote path should be {prefix}/{hf_repo} with correct stripping of trailing slash."""
99108
monkeypatch.setattr(
100109
"model_engine_server.domain.use_cases.model_weights_manager.hmi_config",
@@ -103,27 +112,27 @@ async def test_s3_path_construction(monkeypatch):
103112
gateway = FakeArtifactGateway(existing_files=[])
104113
manager = ModelWeightsManager(llm_artifact_gateway=gateway)
105114

106-
path = manager._get_remote_path("myorg/mymodel")
115+
path = manager.get_remote_path("myorg/mymodel")
107116
assert path == "s3://bucket/prefix/myorg/mymodel"
108117

109118

110119
@pytest.mark.asyncio
111120
async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source():
112-
"""CreateLLMModelEndpointV1UseCase should call model_weights_manager when source is HF and checkpoint_path is None."""
121+
"""CreateLLMModelEndpointV1UseCase should call ensure_model_weights_available (sync),
122+
which returns the expected checkpoint path immediately and fires weight sync in the
123+
background. All following actions (bundle, endpoint creation) proceed without blocking."""
113124
from model_engine_server.domain.entities import LLMSource
114125
from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
115126

116127
mock_manager = MagicMock(spec=ModelWeightsManager)
117-
mock_manager.ensure_model_weights_available = AsyncMock(
118-
return_value="s3://bucket/weights/huggyllama/llama-7b"
128+
mock_manager.ensure_model_weights_available.return_value = (
129+
"s3://bucket/weights/huggyllama/llama-7b"
119130
)
120131

121132
# Use a real SUPPORTED_MODELS_INFO entry: "llama-2-7b" -> "huggyllama/llama-7b"
122133
from tests.unit.conftest import FakeLLMArtifactGateway
123134

124135
fake_gateway = FakeLLMArtifactGateway()
125-
# Ensure the resolved checkpoint path is found in the fake bucket
126-
fake_gateway.s3_bucket["s3://bucket/weights/huggyllama/llama-7b"] = ["model.safetensors"]
127136

128137
from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import (
129138
CreateLLMModelEndpointV1UseCase,
@@ -204,9 +213,8 @@ async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source():
204213
mock_authz.return_value.get_s3_bucket_for_user = MagicMock(return_value="test-bucket")
205214
await use_case.execute(user=user, request=request)
206215

207-
mock_manager.ensure_model_weights_available.assert_called_once_with(
208-
hf_repo="huggyllama/llama-7b"
209-
)
216+
# ensure_model_weights_available is called synchronously — no await, no blocking
217+
mock_manager.ensure_model_weights_available.assert_called_once_with("huggyllama/llama-7b")
210218
# Verify that the resolved checkpoint path was forwarded to the bundle use case
211219
bundle_call_kwargs = mock_bundle_use_case.execute.call_args.kwargs
212220
assert bundle_call_kwargs["checkpoint_path"] == "s3://bucket/weights/huggyllama/llama-7b"

0 commit comments

Comments (0)