Skip to content

Commit 3e7a2a4

Browse files
lilyz-ai and claude committed
feat: add ModelWeightsManager to auto-sync HF weights on endpoint creation
When a model endpoint is created via POST /v1/llm/model-endpoints with source=HUGGING_FACE and no checkpoint_path, ModelWeightsManager now automatically checks the configured S3/GCS/ABS prefix for cached weights and, if they are missing, downloads them from HuggingFace Hub and uploads them — eliminating the manual sync_model_weights.py step.

- Add ModelWeightsManager with ensure_model_weights_available() (async-safe via run_in_executor, cache-hit skips all I/O)
- Add upload_files() abstract method to LLMArtifactGateway with implementations for S3, GCS, and ABS
- Wire ModelWeightsManager into CreateLLMModelEndpointV1UseCase and the create_model_endpoint API handler
- Fix huggingface_hub.utils._errors import for hub>=0.36 compatibility
- Add unit tests covering cache hit/miss, path construction, and end-to-end integration with CreateLLMModelEndpointV1UseCase

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4a90a55 commit 3e7a2a4

10 files changed

Lines changed: 357 additions & 3 deletions

File tree

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
UpdateLLMModelEndpointV1UseCase,
8787
)
8888
from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
89+
from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
8990
from pydantic import RootModel
9091
from sse_starlette.sse import EventSourceResponse
9192

@@ -168,11 +169,15 @@ async def create_model_endpoint(
168169
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
169170
docker_repository=external_interfaces.docker_repository,
170171
)
172+
model_weights_manager = ModelWeightsManager(
173+
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
174+
)
171175
use_case = CreateLLMModelEndpointV1UseCase(
172176
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
173177
model_endpoint_service=external_interfaces.model_endpoint_service,
174178
docker_repository=external_interfaces.docker_repository,
175179
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
180+
model_weights_manager=model_weights_manager,
176181
)
177182
return await use_case.execute(user=auth, request=request)
178183
except ObjectAlreadyExistsException as exc:

model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[
4040
"""
4141
pass
4242

43+
@abstractmethod
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """
    Copies every file found under a local directory to a remote location.

    Args:
        local_path (str): directory on local disk whose files should be uploaded
        remote_path (str): destination path in remote storage (s3://, gs://, or https://)
    """
    ...
53+
4354
@abstractmethod
4455
def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
4556
"""

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,12 +1322,14 @@ def __init__(
13221322
model_endpoint_service: ModelEndpointService,
13231323
docker_repository: DockerRepository,
13241324
llm_artifact_gateway: LLMArtifactGateway,
1325+
model_weights_manager=None,
13251326
):
13261327
self.authz_module = LiveAuthorizationModule()
13271328
self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
13281329
self.model_endpoint_service = model_endpoint_service
13291330
self.docker_repository = docker_repository
13301331
self.llm_artifact_gateway = llm_artifact_gateway
1332+
self.model_weights_manager = model_weights_manager
13311333

13321334
async def execute(
13331335
self, user: User, request: CreateLLMModelEndpointV1Request
@@ -1387,6 +1389,19 @@ async def execute(
13871389
"Multinode endpoints are only supported for VLLM models."
13881390
)
13891391

1392+
# Resolve checkpoint path: auto-download from HF Hub to remote storage if not cached
1393+
checkpoint_path = request.checkpoint_path
1394+
if (
1395+
checkpoint_path is None
1396+
and request.source == LLMSource.HUGGING_FACE
1397+
and self.model_weights_manager is not None
1398+
):
1399+
models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
1400+
if models_info and models_info.hf_repo:
1401+
checkpoint_path = await self.model_weights_manager.ensure_model_weights_available(
1402+
hf_repo=models_info.hf_repo
1403+
)
1404+
13901405
bundle = await self.create_llm_model_bundle_use_case.execute(
13911406
user,
13921407
endpoint_name=request.name,
@@ -1397,7 +1412,7 @@ async def execute(
13971412
endpoint_type=request.endpoint_type,
13981413
num_shards=request.num_shards,
13991414
quantize=request.quantize,
1400-
checkpoint_path=request.checkpoint_path,
1415+
checkpoint_path=checkpoint_path,
14011416
chat_template_override=request.chat_template_override,
14021417
nodes_per_worker=request.nodes_per_worker,
14031418
additional_args=request.model_dump(exclude_none=True),
@@ -1430,7 +1445,7 @@ async def execute(
14301445
inference_framework_image_tag=request.inference_framework_image_tag,
14311446
num_shards=request.num_shards,
14321447
quantize=request.quantize,
1433-
checkpoint_path=request.checkpoint_path,
1448+
checkpoint_path=checkpoint_path,
14341449
chat_template_override=request.chat_template_override,
14351450
)
14361451
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import asyncio
2+
import functools
3+
import tempfile
4+
from typing import List
5+
6+
from huggingface_hub import snapshot_download
7+
from model_engine_server.common.config import hmi_config
8+
from model_engine_server.core.loggers import logger_name, make_logger
9+
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
10+
11+
logger = make_logger(logger_name())
12+
13+
# Glob patterns forwarded as ``ignore_patterns`` to ``snapshot_download`` so that
# matching repository files are not fetched. Kept in step with the
# inclusion/exclusion patterns of the internal sync_model_weights.py script.
HF_IGNORE_PATTERNS: List[str] = [
    "optimizer*", "*.msgpack", "*.h5",
    "flax_model*", "tf_model*", "rust_model*",
]
22+
23+
24+
class ModelWeightsManager:
    """
    Makes sure HuggingFace model weights exist under the configured remote storage prefix.

    On a cache miss the weights are downloaded from HuggingFace Hub into a temporary
    directory and pushed to remote storage through the supplied ``LLMArtifactGateway``.
    """

    def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
        """
        Args:
            llm_artifact_gateway: gateway used to list and upload files in remote storage.
        """
        self.llm_artifact_gateway = llm_artifact_gateway

    def _get_remote_path(self, hf_repo: str) -> str:
        """Return ``<hf_user_fine_tuned_weights_prefix>/<hf_repo>`` without a doubled slash."""
        prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
        return f"{prefix}/{hf_repo}"

    def _download_and_upload(self, hf_repo: str, remote_path: str) -> None:
        """
        Blocking helper: download ``hf_repo`` to a temp dir and upload it to ``remote_path``.

        Meant to run in an executor thread so that the download, the upload and the
        temporary-directory cleanup all stay off the event loop.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            snapshot_download(
                repo_id=hf_repo,
                local_dir=tmp_dir,
                ignore_patterns=HF_IGNORE_PATTERNS,
            )
            self.llm_artifact_gateway.upload_files(tmp_dir, remote_path)

    async def ensure_model_weights_available(self, hf_repo: str) -> str:
        """
        Ensures model weights for ``hf_repo`` are available at the configured remote path.

        If the remote path already contains files it is treated as cached and returned
        immediately. Otherwise the weights are downloaded from HuggingFace Hub and
        uploaded to the remote path.

        Note: the cache check only tests whether *any* file exists under the remote
        path, so a previously interrupted upload is also reported as a cache hit.

        Args:
            hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.

        Returns:
            The remote path (s3://, gs://, or https://) where the weights are stored.

        Raises:
            ValueError: if ``hf_repo`` is empty.
        """
        # An empty repo id would resolve to the bare prefix; listing it would report
        # every other cached model as a "hit" for this one.
        if not hf_repo:
            raise ValueError("hf_repo must be a non-empty HuggingFace repository ID")

        remote_path = self._get_remote_path(hf_repo)
        # get_running_loop() is the supported way to obtain the loop inside a coroutine.
        loop = asyncio.get_running_loop()

        # Listing remote storage is blocking network I/O, so keep it off the event loop.
        files = await loop.run_in_executor(
            None,
            functools.partial(self.llm_artifact_gateway.list_files, remote_path),
        )
        if files:
            logger.info(f"Cache hit: {len(files)} files at {remote_path}")
            return remote_path

        logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
        await loop.run_in_executor(
            None,
            functools.partial(self._download_and_upload, hf_repo, remote_path),
        )

        logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
        return remote_path

model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
5959
downloaded_files.append(local_path)
6060
return downloaded_files
6161

62+
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """
    Upload every file under ``local_path`` to the blob container named in ``remote_path``.

    The directory tree is walked recursively; each file's blob name is the key parsed
    from ``remote_path`` joined with the file's path relative to ``local_path``.
    Blobs that already exist are overwritten.

    Args:
        local_path (str): local directory containing files to upload
        remote_path (str): remote destination path; parsed with ``parse_attachment_url``
    """
    # Local import keeps this method self-contained; blob names always use "/".
    import posixpath

    parsed = parse_attachment_url(remote_path, clean_key=False)
    container_client = _get_abs_container_client(parsed.bucket)
    for root, _, files in os.walk(local_path):
        for file in files:
            local_file = os.path.join(root, file)
            # Normalise the OS separator so nested files keep "/"-delimited names
            # regardless of the platform the server runs on.
            relative = os.path.relpath(local_file, local_path).replace(os.sep, "/")
            blob_name = posixpath.join(parsed.key, relative)
            with open(local_file, "rb") as f:
                container_client.upload_blob(name=blob_name, data=f, overwrite=True)
71+
6272
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
6373
parsed_remote = parse_attachment_url(
6474
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False

model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
5252
downloaded_files.append(local_path)
5353
return downloaded_files
5454

55+
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """
    Upload every file under ``local_path`` to the GCS bucket named in ``remote_path``.

    The directory tree is walked recursively; each file's blob name is the key parsed
    from ``remote_path`` joined with the file's path relative to ``local_path``.

    Args:
        local_path (str): local directory containing files to upload
        remote_path (str): remote destination path; parsed with ``parse_attachment_url``
    """
    # Local import keeps this method self-contained; blob names always use "/".
    import posixpath

    parsed = parse_attachment_url(remote_path, clean_key=False)
    client = get_gcs_sync_client()
    bucket = client.bucket(parsed.bucket)
    for root, _, files in os.walk(local_path):
        for file in files:
            local_file = os.path.join(root, file)
            # Normalise the OS separator so nested files keep "/"-delimited names
            # regardless of the platform the server runs on.
            relative = os.path.relpath(local_file, local_path).replace(os.sep, "/")
            blob_name = posixpath.join(parsed.key, relative)
            bucket.blob(blob_name).upload_from_filename(local_file)
64+
5565
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
5666
parsed_remote = parse_attachment_url(
5767
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False

model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
5656
logger.info(f"Downloaded {len(downloaded_files)} files to {target_path}")
5757
return downloaded_files
5858

59+
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """
    Upload every file under ``local_path`` to the S3 bucket named in ``remote_path``.

    The directory tree is walked recursively; each file's object key is the key parsed
    from ``remote_path`` joined with the file's path relative to ``local_path``.
    Each upload is logged.

    Args:
        local_path (str): local directory containing files to upload
        remote_path (str): remote destination path; parsed with ``parse_attachment_url``
        **kwargs: forwarded as a dict to ``get_s3_resource``
    """
    # Local import keeps this method self-contained; S3 keys always use "/".
    import posixpath

    s3 = get_s3_resource(kwargs)
    parsed = parse_attachment_url(remote_path, clean_key=False)
    bucket = s3.Bucket(parsed.bucket)
    for root, _, files in os.walk(local_path):
        for file in files:
            local_file = os.path.join(root, file)
            # Normalise the OS separator so nested files keep "/"-delimited keys
            # regardless of the platform the server runs on.
            relative = os.path.relpath(local_file, local_path).replace(os.sep, "/")
            s3_key = posixpath.join(parsed.key, relative)
            logger.info(f"Uploading {local_file} → s3://{parsed.bucket}/{s3_key}")
            bucket.upload_file(local_file, s3_key)
69+
5970
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
6071
s3 = get_s3_resource(kwargs)
6172
parsed_remote = parse_attachment_url(

model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from typing import Dict, NamedTuple, Optional
44

55
from huggingface_hub import list_repo_refs
6-
from huggingface_hub.utils._errors import RepositoryNotFoundError
6+
7+
# Prefer the public ``huggingface_hub.errors`` module. The private
# ``huggingface_hub.utils._errors`` path is kept only as a fallback for hub
# releases that do not provide the public module.
try:
    from huggingface_hub.errors import RepositoryNotFoundError
except ImportError:
    from huggingface_hub.utils._errors import RepositoryNotFoundError
711
from model_engine_server.core.loggers import logger_name, make_logger
812
from model_engine_server.domain.exceptions import ObjectNotFoundException
913
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway

model-engine/tests/unit/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
862862
if path in self.s3_bucket:
863863
return self.s3_bucket[path]
864864

865+
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """No-op stand-in for the gateway upload: accepts the call and stores nothing."""
    return None
867+
865868
def get_model_weights_urls(self, owner: str, model_name: str):
866869
if (owner, model_name) in self.existing_models:
867870
return self.urls

0 commit comments

Comments
 (0)