Skip to content

Commit 4a90a55

Browse files
authored
fix cases of synchronous code blocking event loop (#755)
1 parent f436d25 commit 4a90a55

11 files changed

Lines changed: 54 additions & 23 deletions

model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class AsyncModelEndpointInferenceGateway(ABC):
1717
"""
1818

1919
@abstractmethod
20-
def create_task(
20+
async def create_task(
2121
self,
2222
topic: str,
2323
predict_request: EndpointPredictV1Request,

model-engine/model_engine_server/domain/gateways/task_queue_gateway.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# This is the abstract class defining putting and retrieving tasks into a queue.
2+
import asyncio
3+
import functools
24
from abc import ABC, abstractmethod
35
from typing import Any, Dict, List, Optional
46

@@ -32,6 +34,36 @@ def send_task(
3234
Returns: The unique identifier for the task.
3335
"""
3436

37+
async def send_task_async(
38+
self,
39+
task_name: str,
40+
queue_name: str,
41+
args: Optional[List[Any]] = None,
42+
kwargs: Optional[Dict[str, Any]] = None,
43+
expires: Optional[int] = None,
44+
) -> CreateAsyncTaskV1Response:
45+
"""
46+
Non-blocking version of send_task that runs in a thread executor
47+
to avoid blocking the event loop.
48+
49+
Note: This is a workaround for Celery's synchronous API. Ideally the
50+
gateway interface (including get_task) would be natively async, but
51+
Celery lacks first-class asyncio support, so we use run_in_executor
52+
as a pragmatic bridge.
53+
"""
54+
loop = asyncio.get_event_loop()
55+
return await loop.run_in_executor(
56+
None,
57+
functools.partial(
58+
self.send_task,
59+
task_name=task_name,
60+
queue_name=queue_name,
61+
args=args,
62+
kwargs=kwargs,
63+
expires=expires,
64+
),
65+
)
66+
3567
@abstractmethod
3668
def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
3769
"""

model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async def execute(
6666
task_name = model_endpoint.record.current_model_bundle.celery_task_name()
6767

6868
inference_gateway = self.model_endpoint_service.get_async_model_endpoint_inference_gateway()
69-
return inference_gateway.create_task(
69+
return await inference_gateway.create_task(
7070
topic=model_endpoint.record.destination,
7171
predict_request=request,
7272
task_timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS,

model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class LiveAsyncModelEndpointInferenceGateway(AsyncModelEndpointInferenceGateway)
2323
def __init__(self, task_queue_gateway: TaskQueueGateway):
2424
self.task_queue_gateway = task_queue_gateway
2525

26-
def create_task(
26+
async def create_task(
2727
self,
2828
topic: str,
2929
predict_request: EndpointPredictV1Request,
@@ -35,7 +35,7 @@ def create_task(
3535
# key in some fields, and root overriding only reflects in the json() output.
3636
predict_args = json.loads(predict_request.json())
3737

38-
send_task_response = self.task_queue_gateway.send_task(
38+
send_task_response = await self.task_queue_gateway.send_task_async(
3939
task_name=task_name,
4040
queue_name=topic,
4141
args=[predict_args, datetime.now(), predict_request.return_pickled],

model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self.resource_gateway = resource_gateway
5252
self.task_queue_gateway = task_queue_gateway
5353

54-
def create_model_endpoint_infra(
54+
async def create_model_endpoint_infra(
5555
self,
5656
*,
5757
model_endpoint_record: ModelEndpointRecord,
@@ -105,7 +105,7 @@ def create_model_endpoint_infra(
105105
default_callback_url=default_callback_url,
106106
default_callback_auth=default_callback_auth,
107107
)
108-
response = self.task_queue_gateway.send_task(
108+
response = await self.task_queue_gateway.send_task_async(
109109
task_name=BUILD_TASK_NAME,
110110
queue_name=get_service_builder_queue(SERVICE_IDENTIFIER, SERVICE_BUILDER_QUEUE),
111111
# celery request is required to be JSON serializables
@@ -196,7 +196,6 @@ async def update_model_endpoint_infra(
196196
default_callback_url = endpoint_config.default_callback_url
197197
if default_callback_auth is None and endpoint_config is not None:
198198
default_callback_auth = endpoint_config.default_callback_auth
199-
200199
aws_role = infra_state.aws_role
201200
results_s3_bucket = infra_state.results_s3_bucket
202201

@@ -225,7 +224,7 @@ async def update_model_endpoint_infra(
225224
default_callback_url=default_callback_url,
226225
default_callback_auth=default_callback_auth,
227226
)
228-
response = self.task_queue_gateway.send_task(
227+
response = await self.task_queue_gateway.send_task_async(
229228
task_name=BUILD_TASK_NAME,
230229
queue_name=get_service_builder_queue(SERVICE_IDENTIFIER, SERVICE_BUILDER_QUEUE),
231230
kwargs=dict(build_endpoint_request_json=build_endpoint_request.dict()),

model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ModelEndpointInfraGateway(ABC):
1818
"""
1919

2020
@abstractmethod
21-
def create_model_endpoint_infra(
21+
async def create_model_endpoint_infra(
2222
self,
2323
*,
2424
model_endpoint_record: ModelEndpointRecord,

model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,10 @@ async def _read_or_submit_tasks(
267267
async def _submit_tasks(
268268
self, queue_name: str, input_path: str, task_name: str
269269
) -> List[BatchEndpointInProgressTask]:
270-
def _create_task(
270+
async def _create_task(
271271
predict_request: BatchEndpointInferencePrediction,
272272
) -> BatchEndpointInProgressTask:
273-
response = self.async_model_endpoint_inference_gateway.create_task(
273+
response = await self.async_model_endpoint_inference_gateway.create_task(
274274
topic=queue_name,
275275
predict_request=predict_request.request,
276276
task_timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS,
@@ -301,9 +301,8 @@ def _create_task(
301301
BatchEndpointInferencePrediction(request=request, reference_id=reference_id)
302302
)
303303

304-
executor = ThreadPoolExecutor()
305-
task_ids = list(executor.map(_create_task, inputs))
306-
return task_ids
304+
task_ids = await asyncio.gather(*[_create_task(inp) for inp in inputs])
305+
return list(task_ids)
307306

308307
def _poll_tasks(
309308
self,

model-engine/model_engine_server/infra/services/live_model_endpoint_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ async def create_model_endpoint(
187187
public_inference=public_inference,
188188
)
189189
)
190-
creation_task_id = self.model_endpoint_infra_gateway.create_model_endpoint_infra(
190+
creation_task_id = await self.model_endpoint_infra_gateway.create_model_endpoint_infra(
191191
model_endpoint_record=model_endpoint_record,
192192
min_workers=min_workers,
193193
max_workers=max_workers,

model-engine/tests/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,7 @@ def __init__(
10711071
def _get_deployment_name(user_id: str, model_endpoint_name: str) -> str:
10721072
return f"{user_id}-{model_endpoint_name}"
10731073

1074-
def create_model_endpoint_infra(
1074+
async def create_model_endpoint_infra(
10751075
self,
10761076
*,
10771077
model_endpoint_record: ModelEndpointRecord,
@@ -1638,7 +1638,7 @@ class FakeAsyncModelEndpointInferenceGateway(AsyncModelEndpointInferenceGateway)
16381638
def __init__(self):
16391639
self.tasks = []
16401640

1641-
def create_task(
1641+
async def create_task(
16421642
self,
16431643
topic: str,
16441644
predict_request: EndpointPredictV1Request,

model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def fake_live_async_model_inference_gateway(fake_task_queue_gateway):
1313

1414

1515
@pytest.mark.asyncio
16-
def test_task_create_get_url(
16+
async def test_task_create_get_url(
1717
fake_live_async_model_inference_gateway: LiveAsyncModelEndpointInferenceGateway,
1818
endpoint_predict_request_1,
1919
):
20-
create_response = fake_live_async_model_inference_gateway.create_task(
20+
create_response = await fake_live_async_model_inference_gateway.create_task(
2121
"test_topic", endpoint_predict_request_1[0], 60
2222
)
2323
task_id = create_response.task_id
@@ -41,11 +41,11 @@ def test_task_create_get_url(
4141

4242

4343
@pytest.mark.asyncio
44-
def test_task_create_get_args_callback(
44+
async def test_task_create_get_args_callback(
4545
fake_live_async_model_inference_gateway: LiveAsyncModelEndpointInferenceGateway,
4646
endpoint_predict_request_2,
4747
):
48-
create_response = fake_live_async_model_inference_gateway.create_task(
48+
create_response = await fake_live_async_model_inference_gateway.create_task(
4949
"test_topic", endpoint_predict_request_2[0], 60
5050
)
5151
task_id = create_response.task_id

0 commit comments

Comments
 (0)