Skip to content

Commit f436d25

Browse files
Added GCP support in the LLM engine (#750)
1 parent 429cf01 commit f436d25

22 files changed

Lines changed: 1311 additions & 42 deletions

model-engine/model_engine_server/api/dependencies.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
CeleryTaskQueueGateway,
6565
DatadogMonitoringMetricsGateway,
6666
FakeMonitoringMetricsGateway,
67+
GCSFileStorageGateway,
68+
GCSFilesystemGateway,
69+
GCSLLMArtifactGateway,
6770
LiveAsyncModelEndpointInferenceGateway,
6871
LiveBatchJobOrchestrationGateway,
6972
LiveBatchJobProgressGateway,
@@ -100,6 +103,9 @@
100103
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
101104
QueueEndpointResourceDelegate,
102105
)
106+
from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
107+
RedisQueueEndpointResourceDelegate,
108+
)
103109
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
104110
SQSQueueEndpointResourceDelegate,
105111
)
@@ -115,6 +121,9 @@
115121
DbTriggerRepository,
116122
ECRDockerRepository,
117123
FakeDockerRepository,
124+
GARDockerRepository,
125+
GCSFileLLMFineTuneEventsRepository,
126+
GCSFileLLMFineTuneRepository,
118127
LiveTokenizerRepository,
119128
LLMFineTuneRepository,
120129
OnPremDockerRepository,
@@ -224,13 +233,18 @@ def _get_external_interfaces(
224233
read_only=read_only,
225234
)
226235

236+
redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool())
237+
227238
queue_delegate: QueueEndpointResourceDelegate
228239
if CIRCLECI:
229240
queue_delegate = FakeQueueEndpointResourceDelegate()
230241
elif infra_config().cloud_provider == "onprem":
231242
queue_delegate = OnPremQueueEndpointResourceDelegate()
232243
elif infra_config().cloud_provider == "azure":
233244
queue_delegate = ASBQueueEndpointResourceDelegate()
245+
elif infra_config().cloud_provider == "gcp":
246+
# GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
247+
queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client)
234248
else:
235249
queue_delegate = SQSQueueEndpointResourceDelegate(
236250
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -245,13 +259,13 @@ def _get_external_interfaces(
245259
elif infra_config().cloud_provider == "azure":
246260
inference_task_queue_gateway = servicebus_task_queue_gateway
247261
infra_task_queue_gateway = servicebus_task_queue_gateway
248-
elif infra_config().celery_broker_type_redis:
262+
elif infra_config().cloud_provider == "gcp" or infra_config().celery_broker_type_redis:
263+
# GCP uses Redis (Memorystore) for Celery broker
249264
inference_task_queue_gateway = redis_task_queue_gateway
250265
infra_task_queue_gateway = redis_task_queue_gateway
251266
else:
252267
inference_task_queue_gateway = sqs_task_queue_gateway
253268
infra_task_queue_gateway = sqs_task_queue_gateway
254-
redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool())
255269
inference_autoscaling_metrics_gateway = (
256270
ASBInferenceAutoscalingMetricsGateway()
257271
if infra_config().cloud_provider == "azure"
@@ -286,6 +300,9 @@ def _get_external_interfaces(
286300
if infra_config().cloud_provider == "azure":
287301
filesystem_gateway = ABSFilesystemGateway()
288302
llm_artifact_gateway = ABSLLMArtifactGateway()
303+
elif infra_config().cloud_provider == "gcp":
304+
filesystem_gateway = GCSFilesystemGateway()
305+
llm_artifact_gateway = GCSLLMArtifactGateway()
289306
else:
290307
# AWS uses S3, on-prem uses MinIO (S3-compatible)
291308
filesystem_gateway = S3FilesystemGateway()
@@ -337,6 +354,11 @@ def _get_external_interfaces(
337354
if infra_config().cloud_provider == "azure":
338355
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path)
339356
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
357+
elif infra_config().cloud_provider == "gcp":
358+
llm_fine_tune_repository = GCSFileLLMFineTuneRepository(
359+
file_path=file_path,
360+
)
361+
llm_fine_tune_events_repository = GCSFileLLMFineTuneEventsRepository()
340362
else:
341363
# AWS uses S3, on-prem uses MinIO (S3-compatible)
342364
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path)
@@ -354,6 +376,8 @@ def _get_external_interfaces(
354376
file_storage_gateway: FileStorageGateway
355377
if infra_config().cloud_provider == "azure":
356378
file_storage_gateway = ABSFileStorageGateway()
379+
elif infra_config().cloud_provider == "gcp":
380+
file_storage_gateway = GCSFileStorageGateway()
357381
else:
358382
# AWS uses S3, on-prem uses MinIO (S3-compatible)
359383
file_storage_gateway = S3FileStorageGateway()
@@ -365,6 +389,8 @@ def _get_external_interfaces(
365389
docker_repository = OnPremDockerRepository()
366390
elif infra_config().cloud_provider == "azure":
367391
docker_repository = ACRDockerRepository()
392+
elif infra_config().cloud_provider == "gcp":
393+
docker_repository = GARDockerRepository()
368394
else:
369395
docker_repository = ECRDockerRepository()
370396

@@ -417,11 +443,13 @@ async def get_external_interfaces():
417443
try:
418444
from plugins.dependencies import get_external_interfaces as get_custom_external_interfaces
419445

420-
yield get_custom_external_interfaces()
446+
ei = get_custom_external_interfaces()
421447
except ModuleNotFoundError:
422-
yield get_default_external_interfaces()
448+
ei = get_default_external_interfaces()
449+
try:
450+
yield ei
423451
finally:
424-
pass
452+
await ei.file_storage_gateway.close()
425453

426454

427455
async def get_external_interfaces_read_only():
@@ -430,11 +458,13 @@ async def get_external_interfaces_read_only():
430458
get_external_interfaces_read_only as get_custom_external_interfaces_read_only,
431459
)
432460

433-
yield get_custom_external_interfaces_read_only()
461+
ei = get_custom_external_interfaces_read_only()
434462
except ModuleNotFoundError:
435-
yield get_default_external_interfaces_read_only()
463+
ei = get_default_external_interfaces_read_only()
464+
try:
465+
yield ei
436466
finally:
437-
pass
467+
await ei.file_storage_gateway.close()
438468

439469

440470
def get_default_auth_repository() -> AuthenticationRepository:

model-engine/model_engine_server/core/celery/celery_autoscaler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def excluded_namespaces():
4343

4444

4545
ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
46+
GCP_MEMORYSTORE_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"
4647
SQS_BROKER = "sqs-message-broker-master"
4748
SERVICEBUS_BROKER = "servicebus-message-broker-master"
4849

@@ -589,6 +590,8 @@ async def main():
589590

590591
BROKER_NAME_TO_CLASS = {
591592
ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
593+
# GCP Memorystore also doesn't support CONFIG GET
594+
GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),
592595
SQS_BROKER: SQSBroker(),
593596
SERVICEBUS_BROKER: ASBBroker(),
594597
}

model-engine/model_engine_server/domain/gateways/file_storage_gateway.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,7 @@ async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
9292
The content of the file, or None if it does not exist.
9393
"""
9494
pass
95+
96+
async def close(self) -> None:
    """Release any resources held by this gateway.

    The default implementation is a no-op; storage-backed subclasses
    override it to dispose of network clients.
    """

model-engine/model_engine_server/infra/gateways/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway
1111
from .fake_model_primitive_gateway import FakeModelPrimitiveGateway
1212
from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway
13+
from .gcs_file_storage_gateway import GCSFileStorageGateway
14+
from .gcs_filesystem_gateway import GCSFilesystemGateway
15+
from .gcs_llm_artifact_gateway import GCSLLMArtifactGateway
1316
from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway
1417
from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway
1518
from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway
@@ -37,6 +40,9 @@
3740
"DatadogMonitoringMetricsGateway",
3841
"FakeModelPrimitiveGateway",
3942
"FakeMonitoringMetricsGateway",
43+
"GCSFileStorageGateway",
44+
"GCSFilesystemGateway",
45+
"GCSLLMArtifactGateway",
4046
"LiveAsyncModelEndpointInferenceGateway",
4147
"LiveBatchJobOrchestrationGateway",
4248
"LiveBatchJobProgressGateway",
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import asyncio
2+
import os
3+
from datetime import timedelta
4+
from typing import List, Optional
5+
6+
from gcloud.aio.storage import Storage
7+
from model_engine_server.core.config import infra_config
8+
from model_engine_server.domain.gateways.file_storage_gateway import (
9+
FileMetadata,
10+
FileStorageGateway,
11+
)
12+
from model_engine_server.infra.gateways.gcs_storage_client import get_gcs_sync_client, parse_gcs_uri
13+
14+
15+
def _get_gcs_key(owner: str, file_id: str) -> str:
16+
return os.path.join(owner, file_id)
17+
18+
19+
def _get_gcs_url(owner: str, file_id: str) -> str:
    """Return the ``gs://`` URI for a user's file in the configured bucket.

    NOTE(review): the config key is named ``s3_bucket`` but presumably
    holds the object-store bucket for every cloud provider — confirm.
    """
    bucket = infra_config().s3_bucket
    return f"gs://{bucket}/{_get_gcs_key(owner, file_id)}"
21+
22+
23+
def _generate_signed_url_sync(uri: str, expiration: int = 3600) -> str:
    """Generate a V4 signed GET URL for *uri*, valid for *expiration* seconds.

    Runs synchronously via the google-cloud-storage client because
    gcloud-aio-storage has no signed-URL support; callers offload this to
    a thread when in an async context.
    """
    bucket_name, blob_name = parse_gcs_uri(uri)
    blob = get_gcs_sync_client().bucket(bucket_name).blob(blob_name)
    return blob.generate_signed_url(
        version="v4",
        method="GET",
        expiration=timedelta(seconds=expiration),
    )
34+
35+
36+
class GCSFileStorageGateway(FileStorageGateway):
    """File storage gateway backed by Google Cloud Storage.

    Async operations use gcloud-aio-storage; signed-URL generation is
    delegated to the synchronous client in a worker thread. Lookup and
    delete methods are best-effort: any failure is reported as a missing
    file / failed deletion rather than raised.
    """

    def __init__(self) -> None:
        # Shared async storage client; its underlying session is released
        # by close().
        self._storage = Storage()

    async def close(self) -> None:
        """Dispose of the async storage client's network resources."""
        await self._storage.close()

    @staticmethod
    def _bucket_name() -> str:
        # The "s3_bucket" config key names the object-store bucket for
        # every provider, GCS included.
        return infra_config().s3_bucket

    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
        """Return a signed download URL for the owner's file.

        Signing is synchronous, so it is offloaded to a thread.
        """
        uri = _get_gcs_url(owner, file_id)
        return await asyncio.to_thread(_generate_signed_url_sync, uri)

    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
        """Fetch object metadata, or None if the object is unavailable."""
        try:
            meta = await self._storage.download_metadata(
                self._bucket_name(), _get_gcs_key(owner, file_id)
            )
            # GCS reports "size" as a string; coerce to int.
            return FileMetadata(
                id=file_id,
                filename=file_id,
                size=int(meta.get("size", 0)),
                owner=owner,
                updated_at=meta.get("updated"),
            )
        except Exception:  # best-effort: treat any failure as "not found"
            return None

    async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
        """Download and UTF-8-decode the object, or None on any failure."""
        try:
            raw = await self._storage.download(
                self._bucket_name(), _get_gcs_key(owner, file_id)
            )
            return raw.decode("utf-8")
        except Exception:  # best-effort: missing object or decode error
            return None

    async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
        """Upload *content* under the owner's prefix; returns the file id."""
        await self._storage.upload(
            self._bucket_name(), _get_gcs_key(owner, filename), content
        )
        return filename

    async def delete_file(self, owner: str, file_id: str) -> bool:
        """Delete the owner's object; False if deletion failed for any reason."""
        try:
            await self._storage.delete(
                self._bucket_name(), _get_gcs_key(owner, file_id)
            )
            return True
        except Exception:
            return False

    async def list_files(self, owner: str) -> List[FileMetadata]:
        """List every file under the owner's prefix, following pagination."""
        bucket = self._bucket_name()
        owner_prefix = f"{owner}/"
        results: List[FileMetadata] = []
        page_token: Optional[str] = None
        while True:
            query = {"prefix": owner}
            if page_token:
                query["pageToken"] = page_token
            page = await self._storage.list_objects(bucket, params=query)
            for entry in page.get("items", []):
                # Strip only the leading "<owner>/" to recover the file id.
                file_id = entry.get("name", "").replace(owner_prefix, "", 1)
                results.append(
                    FileMetadata(
                        id=file_id,
                        filename=file_id,
                        size=int(entry.get("size", 0)),
                        owner=owner,
                        updated_at=entry.get("updated"),
                    )
                )
            page_token = page.get("nextPageToken")
            if not page_token:
                break
        return results
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import asyncio
2+
from datetime import timedelta
3+
from typing import IO
4+
5+
import smart_open
6+
from gcloud.aio.storage import Storage
7+
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway
8+
from model_engine_server.infra.gateways.gcs_storage_client import get_gcs_sync_client, parse_gcs_uri
9+
10+
11+
class GCSFilesystemGateway(FilesystemGateway):
    """Filesystem gateway backed by Google Cloud Storage.

    Implements the synchronous FilesystemGateway contract and additionally
    offers async-native helpers built on gcloud-aio-storage for use from
    coroutine code.
    """

    def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
        """Open *uri* via smart_open using the synchronous GCS client."""
        transport = {"client": get_gcs_sync_client()}
        return smart_open.open(uri, mode, transport_params=transport)

    def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        """Produce a V4 signed GET URL valid for *expiration* seconds.

        Extra keyword arguments are forwarded to
        ``Blob.generate_signed_url``.
        """
        bucket_name, blob_name = parse_gcs_uri(uri)
        blob = get_gcs_sync_client().bucket(bucket_name).blob(blob_name)
        return blob.generate_signed_url(
            version="v4",
            expiration=timedelta(seconds=expiration),
            method="GET",
            **kwargs,
        )

    async def async_read(self, uri: str) -> bytes:
        """Download the blob at *uri* without blocking the event loop."""
        bucket_name, blob_name = parse_gcs_uri(uri)
        # A fresh client per call keeps this method self-contained; the
        # context manager guarantees the session is closed.
        async with Storage() as client:
            return await client.download(bucket_name, blob_name)

    async def async_write(self, uri: str, content: bytes) -> None:
        """Upload *content* to *uri* without blocking the event loop."""
        bucket_name, blob_name = parse_gcs_uri(uri)
        async with Storage() as client:
            await client.upload(bucket_name, blob_name, content)

    async def async_generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        """Thread-offloaded wrapper around the synchronous signed-URL path."""
        return await asyncio.to_thread(self.generate_signed_url, uri, expiration, **kwargs)

0 commit comments

Comments
 (0)