From 586a806224667871f34a4717964e3ce257f484fc Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Tue, 5 May 2026 20:10:27 -0700 Subject: [PATCH] feat: add plugin interface --- .github/hooks/pre-commit | 9 + .github/scripts/lintcommit.py | 6 +- .../examples-catalog.json | 11 + .../src/plugin/execution_with_plugin.py | 63 ++ .../template.yaml | 18 + .../test/plugin/test_plugin.py | 24 + .../execution.py | 88 +- .../lambda_service.py | 64 ++ .../operation/child.py | 10 + .../operation/step.py | 17 +- .../operation/wait_for_condition.py | 12 + .../plugin.py | 386 +++++++++ .../aws_durable_execution_sdk_python/state.py | 41 +- .../tests/e2e/checkpoint_response_int_test.py | 22 +- .../tests/e2e/execution_int_test.py | 14 +- .../e2e/map_with_concurrent_waits_int_test.py | 2 + .../tests/execution_test.py | 404 +++++++-- .../tests/logger_test.py | 5 + .../tests/plugin_test.py | 784 ++++++++++++++++++ .../tests/state_test.py | 515 +++++++++++- 20 files changed, 2356 insertions(+), 139 deletions(-) create mode 100755 .github/hooks/pre-commit create mode 100644 packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py create mode 100644 packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py create mode 100644 packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py create mode 100644 packages/aws-durable-execution-sdk-python/tests/plugin_test.py diff --git a/.github/hooks/pre-commit b/.github/hooks/pre-commit new file mode 100755 index 00000000..a9a1af02 --- /dev/null +++ b/.github/hooks/pre-commit @@ -0,0 +1,9 @@ +#!/bin/sh + +if hatch fmt --check; then + echo "Hatch fmt check passed!" +else + hatch fmt + echo "Error: hatch fmt modified your files. Please re-stage and commit again." + exit 1 +fi \ No newline at end of file diff --git a/.github/scripts/lintcommit.py b/.github/scripts/lintcommit.py index f24ab886..255ea0ec 100644 --- a/.github/scripts/lintcommit.py +++ b/.github/scripts/lintcommit.py @@ -164,7 +164,8 @@ def lint_range(git_range: str, *, skip_dirty_check: bool = False) -> LintResult: status = subprocess.run( ["git", "status", "--porcelain"], capture_output=True, - text=True, check=False, + text=True, + check=False, ) if status.stdout.strip(): return LintResult( @@ -178,7 +179,8 @@ def lint_range(git_range: str, *, skip_dirty_check: bool = False) -> LintResult: result = subprocess.run( ["git", "log", "--no-merges", git_range, "-z", "--format=%H%n%B"], capture_output=True, - text=True, check=False, + text=True, + check=False, ) if result.returncode != 0: return LintResult(git_error=result.stderr.strip()) diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index fb9ab785..fe8ccfd7 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -602,6 +602,17 @@ "ExecutionTimeout": 300 }, "path": "./src/parallel/parallel_with_named_branches.py" + }, + { + "name": "Plugin", + "description": "Test plugin", + "handler": "execution_with_plugin.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/plugin/execution_with_plugin.py" } ] } diff --git a/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py new file mode 100644 index 00000000..71594573 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py @@ -0,0 +1,63 @@ +"""Demonstrates handler execution without any durable operations.""" + +import logging +from typing import Any + +from aws_durable_execution_sdk_python import StepContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, +) + + +class MyPlugin(DurableExecutionPlugin): + logger = logging.getLogger("MyPlugin") + + def on_operation_start(self, info): + self.logger.info(f"Operation started: {info}") + + def on_operation_end(self, info): + self.logger.info(f"Operation ended: {info}") + + def on_invocation_start(self, info): + self.logger.info(f"Invocation started: {info}") + + def on_invocation_end(self, info): + self.logger.info(f"Invocation ended: {info}") + + def on_user_function_start(self, info) -> None: + self.logger.info(f"User function started: {info}") + + def on_user_function_end(self, info) -> None: + self.logger.info(f"User function ended: {info}") + + +@durable_step +def add_numbers(_step_context: StepContext, a: int, b: int) -> int: + return a + b + + +@durable_with_child_context +def add_numbers_in_child(child_context: DurableContext, a: int, b: int): + result: int = child_context.step( + add_numbers(a, b), + name="add-a-and-b", + ) + return result + + +@durable_execution(plugins=[MyPlugin()]) +def handler(_event: Any, context: DurableContext) -> int: + result: int = context.run_in_child_context( + add_numbers_in_child(6, 4), + name="add-6-and-4", + ) + return context.step( + add_numbers(result, 2), + name="add-result-to-2", + ) diff --git a/packages/aws-durable-execution-sdk-python-examples/template.yaml b/packages/aws-durable-execution-sdk-python-examples/template.yaml index 2854e729..bf91637f 100644 --- a/packages/aws-durable-execution-sdk-python-examples/template.yaml +++ b/packages/aws-durable-execution-sdk-python-examples/template.yaml @@ -977,6 +977,24 @@ "ExecutionTimeout": 300 } } + }, + "ExecutionWithPlugin": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "execution_with_plugin.handler", + "Description": "Test plugin", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py new file mode 100644 index 00000000..5e21ba6e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py @@ -0,0 +1,24 @@ +"""Tests for step example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus + +from src.plugin import execution_with_plugin +from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=execution_with_plugin.handler, + lambda_function_name="Plugin", +) +def test_plugin(durable_runner): + """Test basic step example.""" + with durable_runner: + result = durable_runner.run(input="{}", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == 12 + + step_result = result.get_step("add-result-to-2") + assert deserialize_operation_payload(step_result.result) == 12 diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index df535b41..251f8ccd 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -6,7 +6,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext @@ -26,6 +25,12 @@ Operation, OperationType, OperationUpdate, + InvocationStatus, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, ) from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -149,62 +154,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - -@dataclass(frozen=True) -class DurableExecutionInvocationOutput: - """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. - - If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution, - payload must be empty for SUCCEEDED/FAILED status. - """ - - status: InvocationStatus - result: str | None = None - error: ErrorObject | None = None - - @classmethod - def from_dict( - cls, data: MutableMapping[str, Any] - ) -> DurableExecutionInvocationOutput: - """Create an instance from a dictionary. - - Args: - data: Dictionary with camelCase keys matching the original structure - - Returns: - A DurableExecutionInvocationOutput instance - """ - status = InvocationStatus(data.get("Status")) - error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None - return cls(status=status, result=data.get("Result"), error=error) - - def to_dict(self) -> MutableMapping[str, Any]: - """Convert to a dictionary with the original field names. - - Returns: - Dictionary with the original camelCase keys - """ - result: MutableMapping[str, Any] = {"Status": self.status.value} - - if self.result is not None: - # large payloads return "", because checkpointed already - result["Result"] = self.result - if self.error: - result["Error"] = self.error.to_dict() - - return result - - @classmethod - def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: - """Create a succeeded invocation output.""" - return cls(status=InvocationStatus.SUCCEEDED, result=result) - - # endregion Invocation models @@ -212,14 +161,29 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableExecutionPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: + """ + Decorator to create a durable execution handler. + + Args: + func: The user function to decorate + boto3_client: Optional boto3 Lambda client to use + plugins: Optional list of plugins to use (EXPERIMENTAL: This + feature has known issues and this parameter may change or be removed.) + """ # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) logger.debug("Starting durable execution handler...") + plugin_executor = PluginExecutor(plugins) + + @plugin_executor.handle_durable_output def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput service_client: DurableServiceClient @@ -255,6 +219,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: operations={}, service_client=service_client, replay_status=ReplayStatus.NEW, + plugin_executor=plugin_executor, ) try: @@ -306,6 +271,13 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) as executor, contextlib.closing(execution_state) as execution_state, ): + # execute the plugins + plugin_executor.on_invocation_start( + execution_arn=invocation_input.durable_execution_arn, + lambda_context=context, + execution_start_time=execution_state.get_execution_operation().start_timestamp, + is_replaying=execution_state.is_replaying(), + ) # Thread 1: Run background checkpoint processing executor.submit(execution_state.checkpoint_batches_forever) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py index aa78e4e8..38c44556 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py @@ -105,6 +105,70 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + # Used internally only: the invocation failed and the backend will retry + RETRY = "RETRY" + + +@dataclass(frozen=True) +class DurableExecutionInvocationOutput: + """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. + + If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution, + payload must be empty for SUCCEEDED/FAILED status. + """ + + status: InvocationStatus + result: str | None = None + error: ErrorObject | None = None + + @classmethod + def from_dict( + cls, data: MutableMapping[str, Any] + ) -> DurableExecutionInvocationOutput: + """Create an instance from a dictionary. + + Args: + data: Dictionary with camelCase keys matching the original structure + + Returns: + A DurableExecutionInvocationOutput instance + """ + status = InvocationStatus(data.get("Status")) + error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None + return cls(status=status, result=data.get("Result"), error=error) + + def to_dict(self) -> MutableMapping[str, Any]: + """Convert to a dictionary with the original field names. + + Returns: + Dictionary with the original camelCase keys + """ + result: MutableMapping[str, Any] = {"Status": self.status.value} + + if self.result is not None: + # large payloads return "", because checkpointed already + result["Result"] = self.result + if self.error: + result["Error"] = self.error.to_dict() + + return result + + @classmethod + def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: + """Create a succeeded invocation output.""" + return cls(status=InvocationStatus.SUCCEEDED, result=result) + + @classmethod + def create_retry(cls, error: ErrorObject) -> DurableExecutionInvocationOutput: + """Create a failed invocation output.""" + return cls(status=InvocationStatus.RETRY, error=error) + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py index ecaf0f9d..fbfbc29f 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py @@ -15,6 +15,7 @@ ErrorObject, OperationSubType, OperationUpdate, + OperationType, ) from aws_durable_execution_sdk_python.operation.base import ( CheckResult, @@ -154,6 +155,13 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: self.operation_identifier.operation_id, self.operation_identifier.name, ) + # todo: fix is_replay + start_info = self.state.on_user_function_start( + self.operation_identifier, + OperationType.CONTEXT, + self.sub_type, + is_replay=True, + ) try: raw_result: T = self.func() @@ -223,6 +231,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Must ensure the child context result is persisted before returning to the parent. # This guarantees the result is durable and child operations won't be re-executed on replay # (unless replay_children=True for large payloads). + self.state.on_user_function_end(start_info) self.state.create_checkpoint(operation_update=success_operation) logger.debug( @@ -246,6 +255,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Checkpoint child context FAIL with blocking (is_sync=True, default). # Must ensure the failure state is persisted before raising the exception. # This guarantees the error is durable and child operations won't be re-executed on replay. + self.state.on_user_function_end(start_info, error_object) self.state.create_checkpoint(operation_update=fail_operation) # InvocationError and its derivatives can be retried. diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py index 8a418fb3..957067d4 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py @@ -17,12 +17,14 @@ from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, OperationUpdate, + OperationType, ) from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.base import ( CheckResult, OperationExecutor, ) +from aws_durable_execution_sdk_python.plugin import UserFunctionStartInfo from aws_durable_execution_sdk_python.retries import RetryDecision, RetryPresets from aws_durable_execution_sdk_python.serdes import deserialize, serialize from aws_durable_execution_sdk_python.suspend import ( @@ -142,7 +144,7 @@ def check_result_status(self) -> CheckResult[T]: ): # Step was previously interrupted in a prior invocation - handle retry msg: str = f"Step operation_id={self.operation_identifier.operation_id} name={self.operation_identifier.name} was previously interrupted" - self.retry_handler(StepInterruptedError(msg), checkpointed_result) + self.retry_handler(StepInterruptedError(msg), checkpointed_result, None) checkpointed_result.raise_callable_error() # Ready to execute if STARTED + AT_LEAST_ONCE @@ -217,6 +219,10 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: ) ) + start_info = self.state.on_user_function_start( + self.operation_identifier, OperationType.STEP, None, False, attempt + ) + try: # This is the actual code provided by the caller to execute durably inside the step raw_result: T = self.func(step_context) @@ -235,6 +241,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Checkpoint SUCCEED operation with blocking (is_sync=True, default). # Must ensure the success state is persisted before returning the result to the caller. # This guarantees the step result is durable and won't be lost if Lambda terminates. + self.state.on_user_function_end(start_info) self.state.create_checkpoint(operation_update=success_operation) logger.debug( @@ -260,7 +267,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: self.operation_identifier.name, ) - self.retry_handler(e, checkpointed_result) + self.retry_handler(e, checkpointed_result, start_info) # If we've failed to raise an exception from the retry_handler, then we are in a # weird state, and should crash terminate the execution msg = "retry handler should have raised an exception, but did not." @@ -270,12 +277,14 @@ def retry_handler( self, error: Exception, checkpointed_result: CheckpointedResult, + start_info: UserFunctionStartInfo | None, ): """Checkpoint and suspend for replay if retry required, otherwise raise error. Args: error: The exception that occurred during step execution checkpointed_result: The checkpoint data containing operation state + start_info: Information about the user function start Raises: SuspendExecution: If retry is scheduled @@ -333,6 +342,8 @@ def retry_handler( # Checkpoint RETRY operation with blocking (is_sync=True, default). # Must ensure retry state is persisted before suspending execution. # This guarantees the retry attempt count and next attempt timestamp are durable. + if start_info: + self.state.on_user_function_end(start_info, error_object) self.state.create_checkpoint(operation_update=retry_operation) suspend_with_optional_resume_delay( @@ -351,6 +362,8 @@ def retry_handler( # Checkpoint FAIL operation with blocking (is_sync=True, default). # Must ensure the failure state is persisted before raising the exception. # This guarantees the error is durable and the step won't be retried on replay. + if start_info: + self.state.on_user_function_end(start_info, error_object) self.state.create_checkpoint(operation_update=fail_operation) if isinstance(error, StepInterruptedError): diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 5c4f1c4c..e3a5a492 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -11,6 +11,8 @@ from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, OperationUpdate, + OperationType, + OperationSubType, ) from aws_durable_execution_sdk_python.logger import LogInfo from aws_durable_execution_sdk_python.operation.base import ( @@ -176,6 +178,13 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: if checkpointed_result.operation and checkpointed_result.operation.step_details: attempt = checkpointed_result.operation.step_details.attempt + 1 + start_info = self.state.on_user_function_start( + self.operation_identifier, + OperationType.STEP, + OperationSubType.WAIT_FOR_CONDITION, + False, + attempt, + ) try: # Execute the check function with the injected logger check_context = WaitForConditionCheckContext( @@ -218,6 +227,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Checkpoint SUCCEED operation with blocking (is_sync=True, default). # Must ensure the final state is persisted before returning to the caller. # This guarantees the condition result is durable and won't be re-evaluated on replay. + self.state.on_user_function_end(start_info) self.state.create_checkpoint(operation_update=success_operation) logger.debug( @@ -251,6 +261,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Checkpoint RETRY operation with blocking (is_sync=True, default). # Must ensure the current state and next attempt timestamp are persisted before suspending. # This guarantees the polling state is durable and will resume correctly on the next invocation. + self.state.on_user_function_end(start_info) # no ErrorObject self.state.create_checkpoint(operation_update=retry_operation) suspend_with_optional_resume_delay( @@ -274,6 +285,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: # Checkpoint FAIL operation with blocking (is_sync=True, default). # Must ensure the failure state is persisted before raising the exception. # This guarantees the error is durable and the condition won't be re-evaluated on replay. + self.state.on_user_function_end(start_info, ErrorObject.from_exception(e)) self.state.create_checkpoint(operation_update=fail_operation) raise diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 00000000..1ca11b15 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,386 @@ +import contextlib +import datetime +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Callable, MutableMapping + +from aws_durable_execution_sdk_python.identifier import OperationIdentifier +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, + OperationUpdate, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.types import LambdaContext + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class OperationInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None + name: str | None + parent_id: str | None + start_time: datetime.datetime | None + + +@dataclass(frozen=True) +class OperationStartInfo(OperationInfo): + pass + + +@dataclass(frozen=True) +class OperationEndInfo(OperationInfo): + status: OperationStatus + end_time: datetime.datetime | None + error: ErrorObject | None + + +@dataclass(frozen=True) +class UserFunctionStartInfo(OperationInfo): + is_replay: bool = False # True if user function is called to replay (CONTEXT) + attempt: int | None = None + + +@dataclass(frozen=True) +class UserFunctionEndInfo(OperationInfo): + is_replay: bool + attempt: int | None + succeeded: bool | None + end_time: datetime.datetime | None + error: ErrorObject | None + + @classmethod + def from_start_info( + cls, start_info: UserFunctionStartInfo, error: ErrorObject | None + ) -> "UserFunctionEndInfo": + return UserFunctionEndInfo( + operation_id=start_info.operation_id, + operation_type=start_info.operation_type, + sub_type=start_info.sub_type, + name=start_info.name, + parent_id=start_info.parent_id, + start_time=start_info.start_time, + is_replay=start_info.is_replay, + attempt=start_info.attempt, + succeeded=error is None, + end_time=datetime.datetime.now(datetime.UTC), + error=error, + ) + + +@dataclass(frozen=True) +class InvocationInfo: + request_id: str | None + execution_arn: str | None + start_time: datetime.datetime | None + is_replay: bool + + +@dataclass(frozen=True) +class InvocationStartInfo(InvocationInfo): + pass + + +@dataclass(frozen=True) +class InvocationEndInfo(InvocationInfo): + status: InvocationStatus + end_time: datetime.datetime | None + error: ErrorObject | None + + @classmethod + def from_durable_execution_invocation_output( + cls, + invocation_start_info: InvocationStartInfo, + output: "DurableExecutionInvocationOutput", + ): + return InvocationEndInfo( + request_id=invocation_start_info.request_id, + execution_arn=invocation_start_info.execution_arn, + start_time=invocation_start_info.start_time, + is_replay=invocation_start_info.is_replay, + status=output.status, + end_time=datetime.datetime.now(datetime.UTC), + error=output.error, + ) + + +class DurableExecutionPlugin: + """Base class for plugins. Override only the methods you need.""" + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + """Called when an invocation starts. This is called within the thread that runs user function handler. + + Args: + info: Information about the invocation. + """ + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + """Called when an invocation ends. This is called within the thread that runs user function handler. + + Args: + info: Information about the invocation. + """ + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + """ + Called when an operation checkpoints STARTED status. This is called NOT within the thread that runs operation. + + Args: + info: Information about the operation. + + """ + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + """ + Called when an operation checkpoints a terminal status. This is called NOT within the thread that runs operation. + + Args: + info: Information about the operation. + """ + pass + + def on_user_function_start(self, info: UserFunctionStartInfo) -> None: + """Called when an operation starts to execute user provided function. This is called within the thread that runs user provided function. + + Args: + info: Information about the operation attempt. + """ + pass + + def on_user_function_end(self, info: UserFunctionEndInfo) -> None: + """Called when an operation finishes executing user provided function. This is called within the thread that runs user provided function. + + Args: + info: Information about the operation attempt. + """ + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +class PluginExecutor: + def __init__(self, plugins: list[DurableExecutionPlugin] | None): + self._plugins = plugins or [] + self._executor: ThreadPoolExecutor | None = None + self._invocation_status: InvocationStartInfo | None = None + + @contextlib.contextmanager + def run(self): + if self._plugins: + self._executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="plugin-executor", + ) + try: + yield + finally: + # Shut down the thread pool, waiting for pending tasks to complete. + if self._executor: + self._executor.shutdown(wait=True) + + @staticmethod + def _dispatch_plugin(plugin: DurableExecutionPlugin, info) -> None: + """Invoke the appropriate plugin callback. Runs inside the thread pool.""" + try: + match info: + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case UserFunctionStartInfo(): + plugin.on_user_function_start(info) + case UserFunctionEndInfo(): + plugin.on_user_function_end(info) + case _: + raise RuntimeError(f"Unknown info type: {type(info)}") + except Exception: + # log and ignore the exception + logger.exception("Plugin %s exception ignored", plugin.__class__.__name__) + + def execute_plugins(self, info, sync): + if not self._executor: + return + for plugin in self._plugins: + if sync: + # this is called synchronously, so plugins will be able to manipulate thread local objects + self._dispatch_plugin(plugin, info) + else: + # this is called asynchronously, so plugins cannot manipulate thread local objects + self._executor.submit(self._dispatch_plugin, plugin, info) + + def on_invocation_start( + self, + execution_arn: str, + is_replaying: bool, + execution_start_time: datetime.datetime | None, + lambda_context: LambdaContext | None, + ) -> None: + aws_request_id = lambda_context.aws_request_id if lambda_context else None + invocation_start_time = ( + datetime.datetime.now(datetime.UTC) + if is_replaying + else execution_start_time + ) + self._invocation_status = InvocationStartInfo( + execution_arn=execution_arn, + request_id=aws_request_id, + is_replay=is_replaying, + start_time=invocation_start_time, + ) + self.execute_plugins(self._invocation_status, sync=True) + + def on_invocation_end( + self, + output: "DurableExecutionInvocationOutput", + ) -> None: + if self._invocation_status is None: + # on_invocation_start not called, skip + return + + invocation_end_info = ( + InvocationEndInfo.from_durable_execution_invocation_output( + self._invocation_status, output + ) + ) + self.execute_plugins(invocation_end_info, sync=True) + + def on_user_function_start( + self, + operation_identifier: OperationIdentifier, + operation_type: OperationType, + sub_type: OperationSubType | None, + is_replay: bool, + attempt: int | None = None, + ) -> UserFunctionStartInfo: + """Execute any registered plugins for the operation when its user function starts to execute.""" + start_info = UserFunctionStartInfo( + operation_id=operation_identifier.operation_id, + operation_type=operation_type, + sub_type=sub_type, + name=operation_identifier.name, + parent_id=operation_identifier.parent_id, + start_time=datetime.datetime.now(datetime.UTC), + is_replay=is_replay, + attempt=attempt, + ) + self.execute_plugins(start_info, sync=True) + return start_info + + def on_user_function_end(self, start_info: UserFunctionStartInfo, error) -> None: + """Execute any registered plugins for the operation when its user function finishes execution.""" + self.execute_plugins( + UserFunctionEndInfo.from_start_info(start_info, error), sync=True + ) + + def on_operation_action(self, update: OperationUpdate): + """Execute any registered plugins for a given operation when an update is checkpointed + + Args: + update: the operation update that is checkpointed + """ + if update.action is OperationAction.START: + # we handle only START action here because on_operation_update may not be able to see a STARTED update + # when START is checkpointed in batch with terminal status updates. + self.execute_plugins( + OperationStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_time=datetime.datetime.now(datetime.UTC), + ), + sync=False, + ) + + def on_operation_update(self, operation: Operation | None): + """Execute any registered plugins for a given operation when it receives an update + + Updates such as STARTED might be omitted because START and completion action (e.g. SUCCEED/FAIL) may be + checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED). + + Note: the operation may not be up-to-date if the checkpoint is called asynchronously. + + Args: + operation: the operation is just checkpointed + """ + if operation and self._is_terminal_status(operation.status): + self.execute_plugins( + OperationEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_time=operation.start_timestamp, + end_time=operation.end_timestamp, + status=operation.status, + error=self._extract_error(operation), + ), + sync=False, + ) + + @staticmethod + def _extract_error(operation: Operation): + if operation.step_details and operation.step_details.error: + return operation.step_details.error + if operation.callback_details and operation.callback_details.error: + return operation.callback_details.error + if operation.chained_invoke_details and operation.chained_invoke_details.error: + return operation.chained_invoke_details.error + if operation.context_details and operation.context_details.error: + return operation.context_details.error + return None + + @staticmethod + def _is_terminal_status(status): + return status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] + + @property + def handle_durable_output(self): + def decorator(func: Callable[[Any, LambdaContext], MutableMapping[str, Any]]): + @functools.wraps(func) + def wrapper(event: Any, context: LambdaContext): + with self.run(): + try: + output = func(event, context) + + self.on_invocation_end( + output=DurableExecutionInvocationOutput.from_dict(output), + ) + return output + except Exception as e: + self.on_invocation_end( + output=DurableExecutionInvocationOutput.create_retry( + ErrorObject.from_exception(e) + ), + ) + raise + + return wrapper + + return decorator diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 83175503..c7ffad08 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -19,6 +19,7 @@ GetExecutionStateError, OrphanedChildException, ) +from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, DurableServiceClient, @@ -29,6 +30,11 @@ OperationType, OperationUpdate, StateOutput, + OperationSubType, +) +from aws_durable_execution_sdk_python.plugin import ( + PluginExecutor, + UserFunctionStartInfo, ) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock @@ -236,6 +242,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -243,6 +250,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugin_executor: PluginExecutor = plugin_executor self._ordered_checkpoint_lock: OrderedLock = OrderedLock() self._operations_lock: Lock = Lock() @@ -274,7 +282,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -283,6 +291,8 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API Raises: GetExecutionStateError: If the API call fails. The error is logged @@ -315,6 +325,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -689,12 +700,18 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for update in updates: + self._plugin_executor.on_operation_action(update) + + for operation in updated_operations: + self._plugin_executor.on_operation_update(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: @@ -903,3 +920,23 @@ def _calculate_operation_size(queued_op: QueuedOperation) -> int: def close(self): self.stop_checkpointing() + + def on_user_function_start( + self, + operation_identifier: OperationIdentifier, + operation_type: OperationType, + sub_type: OperationSubType | None, + is_replay: bool, + attempt: int | None = None, + ) -> UserFunctionStartInfo: + return self._plugin_executor.on_user_function_start( + operation_identifier, operation_type, sub_type, is_replay, attempt + ) + + def on_user_function_end( + self, start_info: UserFunctionStartInfo, error: ErrorObject | None = None + ): + self._plugin_executor.on_user_function_end(start_info, error) + + def on_operation_update(self, operation: Operation | None): + self._plugin_executor.on_operation_update(operation) diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py b/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py index c0fd0f50..de168afc 100644 --- a/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py @@ -28,7 +28,7 @@ ) if TYPE_CHECKING: - from aws_durable_execution_sdk_python.types import StepContext + from aws_durable_execution_sdk_python.types import StepContext, LambdaContext def create_mock_checkpoint_with_operations(): @@ -101,7 +101,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -164,7 +164,7 @@ def my_handler(event, context: DurableContext) -> list[str]: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -220,7 +220,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -279,7 +279,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -388,7 +388,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -440,7 +440,7 @@ def my_handler(event, context: DurableContext): mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -499,7 +499,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -598,7 +598,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -665,7 +665,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -730,7 +730,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py b/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py index 5a884bff..ed774632 100644 --- a/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py @@ -135,7 +135,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -221,7 +221,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -262,7 +262,7 @@ def mock_checkpoint( 123, "str", extra={ - "executionArn": "test-arn", + "executionArn": "test-arn/execution-1", "operationName": "mystep", "attempt": 1, "operationId": operation_id, @@ -308,7 +308,7 @@ def my_handler(event, context): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -409,7 +409,7 @@ def mock_checkpoint_failure( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -463,7 +463,7 @@ def my_handler(event: Any, context: DurableContext): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -560,7 +560,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py b/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py index 8ad812e4..62ad7c2b 100644 --- a/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py @@ -42,6 +42,7 @@ OperationUpdate, OperationType, ) +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, ExecutionState, @@ -68,6 +69,7 @@ def _make_state( operations={}, service_client=mock_client, batcher_config=config, + plugin_executor=PluginExecutor([]), ) diff --git a/packages/aws-durable-execution-sdk-python/tests/execution_test.py b/packages/aws-durable-execution-sdk-python/tests/execution_test.py index db13b5a9..390982b3 100644 --- a/packages/aws-durable-execution-sdk-python/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/execution_test.py @@ -23,7 +23,6 @@ from aws_durable_execution_sdk_python.execution import ( DurableExecutionInvocationInput, DurableExecutionInvocationInputWithClient, - DurableExecutionInvocationOutput, InitialExecutionState, InvocationStatus, durable_execution, @@ -46,7 +45,9 @@ StateOutput, StepDetails, WaitDetails, + DurableExecutionInvocationOutput, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin LARGE_RESULT = "large_success" * 1024 * 1024 @@ -56,7 +57,7 @@ def test_durable_execution_invocation_input_from_dict(): """Test that DurableExecutionInvocationInput.from_dict works correctly""" input_dict = { - "DurableExecutionArn": "9692ca80-399d-4f52-8d0a-41acc9cd0492", + "DurableExecutionArn": "9692ca80-399d-4f52-8d0a-41acc9cd0492/9692ca80-399d-4f52-8d0a-41acc9cd0492", "CheckpointToken": "9692ca80-399d-4f52-8d0a-41acc9cd0492", "InitialExecutionState": { "Operations": [ @@ -76,7 +77,10 @@ def test_durable_execution_invocation_input_from_dict(): result = DurableExecutionInvocationInput.from_dict(input_dict) - assert result.durable_execution_arn == "9692ca80-399d-4f52-8d0a-41acc9cd0492" + assert ( + result.durable_execution_arn + == "9692ca80-399d-4f52-8d0a-41acc9cd0492/9692ca80-399d-4f52-8d0a-41acc9cd0492" + ) assert result.checkpoint_token == "9692ca80-399d-4f52-8d0a-41acc9cd0492" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 1 @@ -167,14 +171,14 @@ def test_durable_execution_invocation_input_to_dict(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -186,14 +190,14 @@ def test_durable_execution_invocation_input_to_dict_not_local(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -207,7 +211,7 @@ def test_durable_execution_invocation_input_with_client_inheritance(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -216,7 +220,7 @@ def test_durable_execution_invocation_input_with_client_inheritance(): # Should inherit to_dict from parent class result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -231,7 +235,7 @@ def test_durable_execution_invocation_input_with_client_from_parent(): initial_state = InitialExecutionState(operations=[], next_marker="") parent_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -360,7 +364,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event with LocalRunner=False event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -412,7 +416,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event with LocalRunner=False event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -469,7 +473,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -516,7 +520,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -571,7 +575,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -617,7 +621,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -664,7 +668,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -702,7 +706,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -748,7 +752,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event dict instead of DurableExecutionInvocationInputWithClient event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -796,7 +800,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -835,7 +839,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -917,7 +921,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -957,7 +961,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1007,7 +1011,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1056,7 +1060,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1104,7 +1108,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1154,7 +1158,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1198,7 +1202,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1242,7 +1246,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1288,7 +1292,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1334,7 +1338,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1381,7 +1385,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1447,7 +1451,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1537,7 +1541,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1620,7 +1624,7 @@ def test_handler(event: Any, context: DurableContext) -> str: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1690,7 +1694,7 @@ def test_handler(event: Any, context: DurableContext) -> str: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1745,7 +1749,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1805,7 +1809,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1862,7 +1866,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1907,7 +1911,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: return {"result": "success"} event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2204,14 +2208,14 @@ def test_durable_execution_invocation_input_to_json_dict_minimal(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_json_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_json_dict(), } @@ -2238,7 +2242,7 @@ def test_durable_execution_invocation_input_to_json_dict_with_timestamps(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -2252,7 +2256,7 @@ def test_durable_execution_invocation_input_to_json_dict_with_timestamps(): assert operation_result["StartTimestamp"] == expected_start_ms assert operation_result["EndTimestamp"] == expected_end_ms - assert result["DurableExecutionArn"] == "arn:test:execution" + assert result["DurableExecutionArn"] == "arn:test:execution/exec1" assert result["CheckpointToken"] == "token123" @@ -2261,14 +2265,14 @@ def test_durable_execution_invocation_input_to_json_dict_empty_operations(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_json_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": {"Operations": [], "NextMarker": ""}, } @@ -2279,7 +2283,7 @@ def test_durable_execution_invocation_input_to_json_dict_empty_operations(): def test_durable_execution_invocation_input_from_json_dict_minimal(): """Test DurableExecutionInvocationInput.from_json_dict with minimal data.""" data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2295,7 +2299,7 @@ def test_durable_execution_invocation_input_from_json_dict_minimal(): result = DurableExecutionInvocationInput.from_json_dict(data) - assert result.durable_execution_arn == "arn:test:execution" + assert result.durable_execution_arn == "arn:test:execution/exec1" assert result.checkpoint_token == "token123" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 1 @@ -2309,7 +2313,7 @@ def test_durable_execution_invocation_input_from_json_dict_with_timestamps(): end_ms = 1672578000000 # 2023-01-01 13:00:00 UTC data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2340,13 +2344,13 @@ def test_durable_execution_invocation_input_from_json_dict_with_timestamps(): def test_durable_execution_invocation_input_from_json_dict_empty_initial_state(): """Test DurableExecutionInvocationInput.from_json_dict handles missing InitialExecutionState.""" data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", } result = DurableExecutionInvocationInput.from_json_dict(data) - assert result.durable_execution_arn == "arn:test:execution" + assert result.durable_execution_arn == "arn:test:execution/exec1" assert result.checkpoint_token == "token123" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 0 @@ -2486,7 +2490,7 @@ def test_durable_execution_invocation_input_json_dict_preserves_non_timestamp_fi ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -2504,7 +2508,7 @@ def test_durable_execution_invocation_input_json_dict_preserves_non_timestamp_fi assert operation_result["CallbackDetails"]["CallbackId"] == "cb123" assert operation_result["CallbackDetails"]["Result"] == "callback_result" - assert result["DurableExecutionArn"] == "arn:test:execution" + assert result["DurableExecutionArn"] == "arn:test:execution/exec1" assert result["CheckpointToken"] == "token123" assert result["InitialExecutionState"]["NextMarker"] == "marker123" @@ -2666,7 +2670,7 @@ def _make_invocation_input(mock_client, next_marker=""): execution_details=ExecutionDetails(input_payload="{}"), ) return DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=InitialExecutionState( operations=[operation], next_marker=next_marker @@ -2711,7 +2715,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert result["Status"] == InvocationStatus.SUCCEEDED.value assert json.loads(result["Result"]) == {"is_replaying": True} mock_client.get_execution_state.assert_called_once_with( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", next_marker="page2", ) @@ -2827,3 +2831,293 @@ def test_handler(event: Any, context: DurableContext) -> dict: _make_invocation_input(mock_client, next_marker="next-page-marker"), _make_lambda_context(), ) + + +# region Plugin Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append(f"execution_end:{info.status.value}") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append(f"invocation_end:{info.status.value}") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +class _FailingPlugin(DurableExecutionPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("plugin boom") + + def on_execution_end(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_start(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("plugin boom") + + +def test_durable_execution_with_plugins_success(): + """Test that plugins receive invocation start/end and execution end on success.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # ExecutionStartInfo dispatches to on_invocation_start in the match block + assert "invocation_start" in plugin.calls + assert "invocation_end:SUCCEEDED" in plugin.calls + + +def test_durable_execution_with_plugins_failure(): + """Test that plugins receive invocation end and execution end on user error.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "user error" + raise ValueError(msg) + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.FAILED.value + assert "invocation_start" in plugin.calls + assert "invocation_end:FAILED" in plugin.calls + + +def test_durable_execution_with_plugins_pending(): + """Test that plugins receive invocation end with PENDING status on suspend.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + raise SuspendExecution("test") + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.PENDING.value + assert "invocation_start" in plugin.calls + assert "invocation_end:PENDING" in plugin.calls + # Execution end should NOT be fired for PENDING + execution_end_calls = [c for c in plugin.calls if c.startswith("execution_end")] + assert len(execution_end_calls) == 0 + + +def test_durable_execution_with_plugins_retryable_error(): + """Test that plugins receive invocation end with RETRY status on retryable error.""" + mock_client = Mock(spec=DurableServiceClient) + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "Retriable error" + raise InvocationError(msg) + + with pytest.raises(InvocationError): + test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert "invocation_start" in plugin.calls + assert "invocation_end:RETRY" in plugin.calls + + +def test_durable_execution_with_multiple_plugins(): + """Test that multiple plugins all receive callbacks.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin1 = _RecordingPlugin() + plugin2 = _RecordingPlugin() + + @durable_execution(plugins=[plugin1, plugin2]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin1.calls + assert "invocation_start" in plugin2.calls + assert "invocation_end:SUCCEEDED" in plugin1.calls + assert "invocation_end:SUCCEEDED" in plugin2.calls + + +def test_durable_execution_with_failing_plugin_does_not_break_execution(): + """Test that a failing plugin does not prevent the handler from completing.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + failing_plugin = _FailingPlugin() + recording_plugin = _RecordingPlugin() + + @durable_execution(plugins=[failing_plugin, recording_plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + # Execution should still succeed despite the failing plugin + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # The recording plugin should still have been called + assert "invocation_start" in recording_plugin.calls + assert "invocation_end:SUCCEEDED" in recording_plugin.calls + + +def test_durable_execution_with_no_plugins(): + """Test that passing no plugins (None) works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=None) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_with_empty_plugins_list(): + """Test that passing an empty plugins list works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=[]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_decorator_with_plugins_and_boto3_client(): + """Test that plugins parameter works alongside boto3_client parameter.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + # When using DurableExecutionInvocationInputWithClient, boto3_client is ignored + # but we verify the decorator accepts both parameters + @durable_execution(boto3_client=None, plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin.calls + + +# endregion Plugin Integration Tests diff --git a/packages/aws-durable-execution-sdk-python/tests/logger_test.py b/packages/aws-durable-execution-sdk-python/tests/logger_test.py index b6017fa6..1966e276 100644 --- a/packages/aws-durable-execution-sdk-python/tests/logger_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/logger_test.py @@ -11,6 +11,7 @@ OperationType, ) from aws_durable_execution_sdk_python.logger import Logger, LoggerInterface, LogInfo +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -83,6 +84,7 @@ def exception( initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), ) @@ -227,6 +229,7 @@ def test_logger_with_log_info(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor([]), ) new_info = LogInfo(execution_state_new, "parent2", "op123", "new_name") new_logger = logger.with_log_info(new_info) @@ -377,6 +380,7 @@ def test_logger_replay_no_logging(): operations={"op1": operation}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) mock_logger = Mock() @@ -404,6 +408,7 @@ def test_logger_replay_then_new_logging(): operations={"op1": operation1, "op2": operation2}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(execution_state, "parent123", "test_name", 5) mock_logger = Mock() diff --git a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py new file mode 100644 index 00000000..e9d96fee --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py @@ -0,0 +1,784 @@ +import datetime +import logging +import unittest +from unittest.mock import MagicMock + +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + InvocationStatus, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, + PluginExecutor, + UserFunctionStartInfo, + UserFunctionEndInfo, +) + + +# region Dataclass Tests + +ERROR = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) +START_TS = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) +END_TS = datetime.datetime(2025, 1, 2, tzinfo=datetime.UTC) +LAMBDA_CTX = MagicMock() +LAMBDA_CTX.aws_request_id = "req-1" + +OPERATION_START_INFO = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_time=START_TS, +) +OPERATION_END_INFO = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="my-op", + parent_id="parent-1", + start_time=START_TS, + status=OperationStatus.FAILED, + end_time=END_TS, + error=ERROR, +) + +INVOCATION_START_INFO = InvocationStartInfo( + request_id="req-1", + execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", + start_time=START_TS, + is_replay=True, +) +INVOCATION_END_INFO = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_time=START_TS, + status=InvocationStatus.FAILED, + error=ERROR, + is_replay=False, + end_time=END_TS, +) + +USER_FUNCTION_START_INFO = UserFunctionStartInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="func", + parent_id="parent-1", + start_time=START_TS, +) + +USER_FUNCTION_END_INFO = UserFunctionEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="func", + parent_id="parent-1", + start_time=START_TS, + is_replay=False, + attempt=1, + succeeded=False, + end_time=END_TS, + error=ERROR, +) + + +class TestDataClasses(unittest.TestCase): + def test_operation_start_info(self): + self.assertEqual(OPERATION_START_INFO.sub_type, OperationSubType.CALLBACK) + self.assertEqual(OPERATION_START_INFO.name, "my-op") + self.assertEqual(OPERATION_START_INFO.parent_id, "parent-1") + self.assertEqual(OPERATION_START_INFO.start_time, START_TS) + + def test_operation_end_info(self): + self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) + self.assertEqual(OPERATION_END_INFO.end_time, END_TS) + self.assertEqual(OPERATION_END_INFO.error, ERROR) + self.assertEqual(OPERATION_END_INFO.operation_type, OperationType.STEP) + self.assertEqual(OPERATION_END_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(OPERATION_END_INFO.name, "my-op") + self.assertEqual(OPERATION_END_INFO.parent_id, "parent-1") + self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") + self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) + self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") + + def test_invocation_start_info(self): + self.assertEqual(INVOCATION_START_INFO.request_id, "req-1") + self.assertEqual( + INVOCATION_START_INFO.execution_arn, + "arn:aws:lambda:us-east-1:123:durable:abc", + ) + self.assertEqual(INVOCATION_START_INFO.start_time, START_TS) + self.assertTrue(INVOCATION_START_INFO.is_replay) + + def test_invocation_end_info(self): + self.assertEqual(INVOCATION_END_INFO.request_id, "req-1") + self.assertEqual(INVOCATION_END_INFO.execution_arn, "arn:test") + self.assertEqual(INVOCATION_END_INFO.start_time, START_TS) + self.assertFalse(INVOCATION_END_INFO.is_replay) + self.assertEqual(INVOCATION_END_INFO.status, InvocationStatus.FAILED) + self.assertEqual(INVOCATION_END_INFO.error.message, "boom") + self.assertEqual(INVOCATION_END_INFO.end_time, END_TS) + + def test_user_function_start_info(self): + self.assertEqual(USER_FUNCTION_START_INFO.operation_id, "op-1") + self.assertEqual(USER_FUNCTION_START_INFO.operation_type, OperationType.STEP) + self.assertEqual(USER_FUNCTION_START_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(USER_FUNCTION_START_INFO.name, "func") + self.assertEqual(USER_FUNCTION_START_INFO.parent_id, "parent-1") + self.assertEqual(USER_FUNCTION_START_INFO.start_time, START_TS) + + def test_user_function_end_info(self): + self.assertEqual(USER_FUNCTION_END_INFO.operation_id, "op-1") + self.assertEqual(USER_FUNCTION_END_INFO.operation_type, OperationType.STEP) + self.assertEqual(USER_FUNCTION_END_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(USER_FUNCTION_END_INFO.name, "func") + self.assertEqual(USER_FUNCTION_END_INFO.parent_id, "parent-1") + self.assertEqual(USER_FUNCTION_END_INFO.start_time, START_TS) + self.assertFalse(USER_FUNCTION_END_INFO.is_replay) + self.assertEqual(USER_FUNCTION_END_INFO.attempt, 1) + self.assertFalse(USER_FUNCTION_END_INFO.succeeded) + self.assertEqual(USER_FUNCTION_END_INFO.end_time, END_TS) + self.assertEqual(USER_FUNCTION_END_INFO.error.message, "boom") + + +# endregion Dataclass Tests + + +# region DurableExecutionPlugin Tests +class TestDurableExecutionPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + self.assertIsNone(plugin.on_invocation_start(INVOCATION_START_INFO)) + self.assertIsNone(plugin.on_invocation_end(INVOCATION_END_INFO)) + self.assertIsNone(plugin.on_operation_start(OPERATION_START_INFO)) + self.assertIsNone(plugin.on_operation_end(OPERATION_END_INFO)) + self.assertIsNone(plugin.on_user_function_start(USER_FUNCTION_START_INFO)) + self.assertIsNone(plugin.on_user_function_end(USER_FUNCTION_END_INFO)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + + plugin.on_invocation_start(INVOCATION_START_INFO) + plugin.on_operation_start(OPERATION_START_INFO) + + self.assertEqual( + ["invocation_start:req-1", "operation_start:op-2"], plugin.calls + ) + + +# endregion DurableExecutionPlugin Tests + + +# region PluginExecutor Tests + + +class TestPluginExecutorInit(unittest.TestCase): + def test_init_with_none(self): + executor = PluginExecutor(plugins=None) + self.assertEqual(executor._plugins, []) + + def test_init_with_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertEqual(executor._plugins, []) + + def test_init_with_plugins(self): + p1 = _NoOpPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + self.assertEqual(len(executor._plugins), 2) + + +class TestPluginExecutor(unittest.TestCase): + def test_no_thread_pool_when_plugins_is_none(self): + """Tests that PluginExecutor does not create a thread pool when plugins is empty.""" + executor = PluginExecutor(plugins=None) + self.assertIsNone(executor._executor) + + def test_no_thread_pool_when_plugins_is_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertIsNone(executor._executor) + + def test_thread_pool_created_when_plugins_provided(self): + executor = PluginExecutor(plugins=[_NoOpPlugin()]) + with executor.run(): + self.assertIsNotNone(executor._executor) + + def test_start_is_noop_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + with executor.run(): + pass + + def test_on_invocation_start_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + + def test_on_invocation_end_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + # Should not raise + executor.on_invocation_end( + output=output, + ) + + def test_on_operation_action_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = None + + # Should not raise + executor.on_operation_action(update) + + def test_on_operation_update_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = None + op.start_time = START_TS + op.end_time = END_TS + op.status = OperationStatus.SUCCEEDED + op.step_details = MagicMock() + op.step_details.attempt = 1 + op.step_details.error = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + # Should not raise + executor.on_operation_update(op) + + +class TestPluginExecutorExecutePlugins(unittest.TestCase): + """Tests for the execute_plugins dispatch method.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def test_dispatch_invocation_start_info(self): + with self.executor.run(): + self.executor.execute_plugins(INVOCATION_START_INFO, sync=True) + self.assertIn("invocation_start:req-1", self.plugin.calls) + + def test_dispatch_invocation_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(INVOCATION_END_INFO, sync=True) + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_dispatch_operation_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(OPERATION_END_INFO, sync=False) + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_dispatch_operation_start_info(self): + with self.executor.run(): + self.executor.execute_plugins(OPERATION_START_INFO, sync=False) + self.assertIn("operation_start:op-2", self.plugin.calls) + + def test_dispatch_user_function_start_info(self): + with self.executor.run(): + self.executor.execute_plugins(USER_FUNCTION_START_INFO, sync=True) + self.assertIn("user_function_start:op-1", self.plugin.calls) + + def test_dispatch_user_function_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(USER_FUNCTION_END_INFO, sync=True) + self.assertIn("user_function_end:op-1", self.plugin.calls) + + def test_dispatch_unknown_type_logs_exception(self): + """Unknown info types should be caught and logged.""" + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + with self.executor.run(): + self.executor.execute_plugins("not a valid info type", sync=True) + + def test_plugin_exception_is_swallowed(self): + """If a plugin raises, the exception is logged and execution continues.""" + failing_plugin = _FailingPlugin() + tracking_plugin = _TrackingPlugin() + executor = PluginExecutor(plugins=[failing_plugin, tracking_plugin]) + + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + with executor.run(): + executor.execute_plugins(OPERATION_START_INFO, sync=True) + + # The second plugin should still have been called + self.assertIn("operation_start:op-2", tracking_plugin.calls) + + def test_multiple_plugins_all_called(self): + p1 = _TrackingPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + + with executor.run(): + executor.execute_plugins(OPERATION_START_INFO, sync=True) + + self.assertIn("operation_start:op-2", p1.calls) + self.assertIn("operation_start:op-2", p2.calls) + + +class TestPluginExecutorOnInvocationStart(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_start.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def _make_operation(self, start_time=None): + op = MagicMock() + op.start_time = start_time or self.ts + return op + + def test_first_invocation_fires_invocation_start(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + + self.assertEqual("arn:exec", self.executor._invocation_status.execution_arn) + self.assertEqual( + LAMBDA_CTX.aws_request_id, self.executor._invocation_status.request_id + ) + self.assertEqual(START_TS, self.executor._invocation_status.start_time) + self.assertFalse(self.executor._invocation_status.is_replay) + + # ExecutionStartInfo dispatches to on_invocation_start in match + # InvocationStartInfo dispatches to on_invocation_start in match + # So we expect two invocation_start calls + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + self.assertEqual(1, len(invocation_calls)) + + def test_replay_invocation_fires_invocation_start(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=True, + ) + + # Only InvocationStartInfo should be dispatched (not ExecutionStartInfo) + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + self.assertEqual(1, len(invocation_calls)) + + def test_none_context_uses_none_request_id(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=None, + execution_start_time=START_TS, + is_replaying=False, + ) + + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + # Both ExecutionStartInfo and InvocationStartInfo dispatched + self.assertEqual(len(invocation_calls), 1) + # request_id should be None + self.assertIn("invocation_start:None", self.plugin.calls) + + +class TestPluginExecutorOnInvocationEnd(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_end.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def _make_operation(self, start_ts=None, end_ts=None): + op = MagicMock() + op.start_time = start_ts or self.ts + op.end_time = end_ts + return op + + def test_succeeded_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_failed_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, result=None, error=ERROR + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_pending_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.PENDING, result=None, error=None + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_replaying=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + +class TestPluginExecutorOnOperationAction(unittest.TestCase): + """Tests for PluginExecutor.on_operation_action.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def test_start_action_fires_operation_start(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = "parent-1" + + with self.executor.run(): + self.executor.on_operation_action(update) + + self.assertIn("operation_start:op-1", self.plugin.calls) + + def test_non_start_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.SUCCEED + update.operation_id = "op-1" + + self.executor.on_operation_action(update) + + self.assertEqual(self.plugin.calls, []) + + def test_fail_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.FAIL + update.operation_id = "op-1" + + self.executor.on_operation_action(update) + + self.assertEqual(self.plugin.calls, []) + + +class TestPluginExecutorOnOperationUpdate(unittest.TestCase): + """Tests for PluginExecutor.on_operation_update.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def _make_operation( + self, + status=OperationStatus.SUCCEEDED, + step_details=None, + callback_details=None, + chained_invoke_details=None, + context_details=None, + ): + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = "parent-1" + op.start_time = START_TS + op.end_time = END_TS + op.status = status + op.step_details = step_details + op.callback_details = callback_details + op.chained_invoke_details = chained_invoke_details + op.context_details = context_details + return op + + def test_terminal_status_without_step_details_fires_operation_only(self): + op = self._make_operation(status=OperationStatus.FAILED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_non_terminal_status_without_step_details_fires_nothing(self): + op = self._make_operation(status=OperationStatus.STARTED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertEqual(self.plugin.calls, []) + + def test_ready_status_fires_nothing(self): + op = self._make_operation(status=OperationStatus.READY, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertEqual(self.plugin.calls, []) + + def test_timed_out_is_terminal(self): + op = self._make_operation(status=OperationStatus.TIMED_OUT, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_cancelled_is_terminal(self): + op = self._make_operation(status=OperationStatus.CANCELLED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_stopped_is_terminal(self): + op = self._make_operation(status=OperationStatus.STOPPED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + +class TestPluginExecutorExtractError(unittest.TestCase): + """Tests for PluginExecutor._extract_error static method.""" + + def test_extract_error_from_step_details(self): + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = ERROR + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_from_callback_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = MagicMock() + op.callback_details.error = ERROR + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_from_chained_invoke_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = MagicMock() + op.chained_invoke_details.error = ERROR + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_from_context_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = MagicMock() + op.context_details.error = ERROR + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_returns_none_when_no_error(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertIsNone(result) + + def test_extract_error_step_details_no_error(self): + """step_details exists but has no error - falls through to callback.""" + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = None + op.callback_details = MagicMock() + op.callback_details.error = ERROR + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + +class TestPluginExecutorIsTerminalStatus(unittest.TestCase): + """Tests for PluginExecutor._is_terminal_status static method.""" + + def test_succeeded_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.SUCCEEDED)) + + def test_failed_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.FAILED)) + + def test_timed_out_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.TIMED_OUT)) + + def test_cancelled_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.CANCELLED)) + + def test_stopped_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.STOPPED)) + + def test_started_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.STARTED)) + + def test_pending_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.PENDING)) + + def test_ready_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.READY)) + + +# endregion PluginExecutor Tests + + +# region Helper Classes + + +class _NoOpPlugin(DurableExecutionPlugin): + """Concrete subclass that inherits all default no-op methods.""" + + pass + + +class _TrackingPlugin(DurableExecutionPlugin): + """Concrete subclass that tracks calls to all hooks.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + self.calls.append(f"invocation_start:{info.request_id}") + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + self.calls.append(f"invocation_end:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info: OperationEndInfo) -> None: + self.calls.append(f"operation_end:{info.operation_id}") + + def on_user_function_start(self, info: UserFunctionStartInfo) -> None: + self.calls.append(f"user_function_start:{info.operation_id}") + + def on_user_function_end(self, info: UserFunctionEndInfo) -> None: + self.calls.append(f"user_function_end:{info.operation_id}") + + +class _FailingPlugin(DurableExecutionPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("boom") + + def on_execution_end(self, info): + raise RuntimeError("boom") + + def on_invocation_start(self, info): + raise RuntimeError("boom") + + def on_invocation_end(self, info): + raise RuntimeError("boom") + + def on_operation_start(self, info): + raise RuntimeError("boom") + + def on_operation_end(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("boom") + + +# endregion Helper Classes + + +if __name__ == "__main__": + unittest.main() diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index 0152ca6c..101784da 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -9,7 +9,7 @@ import time import unittest.mock from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, call, patch, create_autospec import pytest @@ -37,6 +37,10 @@ StateOutput, StepDetails, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, +) from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, CheckpointedResult, @@ -332,7 +336,7 @@ def test_checkpointerd_result_is_pending(): assert result_no_op.is_pending() is False -def test_checkpointerd_result_is_ready(): +def test_checkpointed_result_is_ready(): """Test CheckpointedResult.is_ready method.""" operation = Operation( operation_id="op1", @@ -405,6 +409,7 @@ def test_execution_state_creation(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) assert state.durable_execution_arn == "test_arn" assert state.operations == {} @@ -425,6 +430,7 @@ def test_get_checkpoint_result_success_with_result(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -446,6 +452,7 @@ def test_get_checkpoint_result_success_without_step_details(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -467,6 +474,7 @@ def test_get_checkpoint_result_operation_not_succeeded(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -483,6 +491,7 @@ def test_get_checkpoint_result_operation_not_found(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("nonexistent") @@ -500,6 +509,7 @@ def test_create_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -530,6 +540,7 @@ def test_create_checkpoint_with_none(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with None and is_sync=False enqueues an empty checkpoint @@ -554,6 +565,7 @@ def test_create_checkpoint_with_no_args(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with no args and is_sync=False enqueues an empty checkpoint @@ -582,6 +594,7 @@ def test_get_checkpoint_result_started(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -675,6 +688,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) state.fetch_paginated_operations( @@ -773,6 +787,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -811,6 +826,7 @@ def test_fetch_paginated_operations_logs_error(caplog): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -920,6 +936,7 @@ def test_checkpoint_batch_respects_default_max_items_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -988,6 +1005,7 @@ def test_collect_checkpoint_batch_respects_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1021,6 +1039,7 @@ def test_collect_checkpoint_batch_uses_overflow_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Put operations in overflow queue @@ -1072,6 +1091,7 @@ def test_collect_checkpoint_batch_handles_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue empty checkpoint @@ -1107,6 +1127,7 @@ def test_collect_checkpoint_batch_returns_empty_when_stopped(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal stop before collecting @@ -1128,6 +1149,7 @@ def test_parent_child_relationship_building(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1169,6 +1191,7 @@ def test_descendant_cancellation_when_parent_completes(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1208,6 +1231,7 @@ def test_rejection_of_operations_from_completed_parents(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1257,6 +1281,7 @@ def test_nested_parallel_operations_deep_hierarchy(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build deep hierarchy: grandparent -> parent -> child @@ -1313,6 +1338,7 @@ def test_synchronous_checkpoint_blocks_until_complete(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -1361,6 +1387,7 @@ def test_concurrent_access_to_operations_dictionary(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add initial operation @@ -1430,6 +1457,7 @@ def test_stop_checkpointing_signals_background_thread(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Verify event is not set initially @@ -1523,6 +1551,7 @@ def test_create_checkpoint_sync_with_parent_id(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1574,6 +1603,7 @@ def test_create_checkpoint_sync_rejects_orphaned_operation(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child relationship @@ -1638,6 +1668,7 @@ def test_mark_orphans_handles_cycles(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Manually create a cycle (shouldn't happen in practice, but test defensive code) @@ -1668,6 +1699,7 @@ def test_checkpoint_batches_forever_exception_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation @@ -1715,6 +1747,7 @@ def test_collect_checkpoint_batch_shutdown_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add operation to queue (would be a non-essential async checkpoint in practice) @@ -1744,6 +1777,7 @@ def test_collect_checkpoint_batch_shutdown_empty_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal shutdown with empty queue @@ -1771,6 +1805,7 @@ def test_collect_checkpoint_batch_overflow_put_back(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1816,6 +1851,7 @@ def test_create_checkpoint_sync_with_none_operation_update(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Simulate background processor @@ -1848,6 +1884,7 @@ def test_checkpoint_batches_forever_exception_with_no_sync_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -1887,6 +1924,7 @@ def test_collect_checkpoint_batch_size_limit_during_time_window(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1940,6 +1978,7 @@ def test_collect_checkpoint_batch_respects_max_operations_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1983,6 +2022,7 @@ def test_collect_checkpoint_batch_time_window_expires(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2030,6 +2070,7 @@ def test_collect_checkpoint_batch_empty_overflow_queue_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Ensure overflow queue is empty (it should be by default) @@ -2067,6 +2108,7 @@ def test_collect_checkpoint_batch_overflow_queue_hits_operation_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2106,6 +2148,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2155,6 +2198,7 @@ def test_checkpoint_error_signals_completion_events_with_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation with completion event @@ -2211,6 +2255,7 @@ def test_synchronous_caller_receives_error_on_background_thread_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2288,6 +2333,7 @@ def test_exception_propagates_through_threadpoolexecutor(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue an operation @@ -2321,6 +2367,7 @@ def test_multiple_sync_operations_all_remain_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create multiple synchronous operations @@ -2372,6 +2419,7 @@ def test_async_operations_not_affected_by_error_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -2409,6 +2457,7 @@ def test_mixed_sync_async_operations_only_sync_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create sync operation with completion event @@ -2469,6 +2518,7 @@ def test_create_checkpoint_accepts_is_sync_parameter(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2503,6 +2553,7 @@ def test_create_checkpoint_default_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2549,6 +2600,7 @@ def test_create_checkpoint_explicit_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2590,6 +2642,7 @@ def test_create_checkpoint_is_sync_false_no_completion_event(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2620,6 +2673,7 @@ def test_create_checkpoint_is_sync_false_returns_immediately(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2658,6 +2712,7 @@ def test_create_checkpoint_with_none_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with None (will block) @@ -2694,6 +2749,7 @@ def test_create_checkpoint_no_args_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with no arguments (will block) @@ -2733,6 +2789,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit_final(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2788,6 +2845,7 @@ def test_create_checkpoint_blocks_until_completion_default(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2859,6 +2917,7 @@ def test_create_checkpoint_blocks_until_completion_explicit_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2930,6 +2989,7 @@ def test_create_checkpoint_completion_event_created_and_signaled(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2994,6 +3054,7 @@ def test_create_checkpoint_completion_event_not_signaled_on_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3080,6 +3141,7 @@ def test_create_checkpoint_caller_remains_blocked_on_background_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3162,6 +3224,7 @@ def test_create_checkpoint_multiple_sync_calls_all_block(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) num_callers = 3 @@ -3238,6 +3301,7 @@ def test_create_checkpoint_sync_with_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Track timing and completion @@ -3296,6 +3360,7 @@ def test_create_checkpoint_sync_success(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3330,6 +3395,7 @@ def test_create_checkpoint_sync_unwraps_background_thread_error(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3363,6 +3429,7 @@ def test_create_checkpoint_sync_always_synchronous(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3400,6 +3467,7 @@ def test_state_replay_mode(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3433,6 +3501,7 @@ def test_state_replay_mode_with_timed_out(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3464,6 +3533,7 @@ def test_collect_checkpoint_batch_coalesces_many_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3497,6 +3567,7 @@ def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3536,6 +3607,7 @@ def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3576,6 +3648,7 @@ def test_checkpoint_batches_forever_single_api_call_for_many_empty_checkpoints() initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3624,6 +3697,7 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3676,6 +3750,7 @@ def test_execution_state_get_execution_operation_no_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3707,6 +3782,7 @@ def test_initial_execution_state_get_execution_operation_wrong_type(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3743,8 +3819,443 @@ def test_initial_execution_state_get_input_payload_none(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) result = state.get_input_payload() assert result is None + + +# region Plugin Executor Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append("execution_end") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append("invocation_end") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_user_function_start(self, info): + self.calls.append(f"user_function_start:{info.operation_id}") + + def on_user_function_end(self, info): + self.calls.append(f"user_function_end:{info.operation_id}") + + +def test_execution_state_accepts_plugin_executor_parameter(): + """Test that ExecutionState can be created with a plugin_executor parameter.""" + mock_client = Mock(spec=LambdaClient) + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + assert state._plugin_executor is plugin_executor + + +def test_plugin_executor_on_operation_action_called_on_checkpoint(): + """Test that plugin_executor.on_operation_action is called for each update after checkpoint.""" + mock_client = create_autospec(LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + # Start background thread + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action is called for START updates + assert "operation_start:step-1" in plugin.calls + + +def test_plugin_executor_on_operation_update_called_for_terminal_operations(): + """Test that plugin_executor.on_operation_update is called for terminal operations.""" + mock_client = create_autospec(LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="my-step", + payload='"done"', + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + assert "operation_end:step-1" in plugin.calls + + +def test_plugin_executor_not_called_for_non_terminal_operations(): + """Test that plugin_executor.on_operation_update does not fire for non-terminal operations.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return a STARTED step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=None, + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action fires for START + assert "operation_start:step-1" in plugin.calls + # But on_operation_update should NOT fire operation_end for STARTED status + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +def test_plugin_executor_called_for_multiple_updates_in_batch(): + """Test that plugin_executor is called for each update in a batch.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return multiple operations from checkpoint + step_op1 = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result1"'), + ) + step_op2 = Operation( + operation_id="step-2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result2"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op1, step_op2], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + config = CheckpointBatcherConfig( + max_batch_time_seconds=0.2, + max_batch_operations=10, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + batcher_config=config, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + op1 = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-1", + ) + op2 = OperationUpdate( + operation_id="step-2", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-2", + ) + # Enqueue both without blocking so they batch together + state.create_checkpoint(op1, is_sync=False) + state.create_checkpoint(op2, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Both operations should have triggered on_operation_action + assert "operation_start:step-1" in plugin.calls + assert "operation_start:step-2" in plugin.calls + # Both terminal operations should have triggered on_operation_update + assert "operation_end:step-1" in plugin.calls + assert "operation_end:step-2" in plugin.calls + + +def test_plugin_executor_not_called_on_checkpoint_failure(): + """Test that plugin_executor is NOT called when checkpoint API fails.""" + mock_client = create_autospec(spec=LambdaClient) + mock_client.checkpoint.side_effect = RuntimeError("API error") + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + + with pytest.raises(BackgroundThreadError): + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Plugin should NOT have been called since checkpoint failed + assert "operation_start:step-1" not in plugin.calls + assert "operation_end:step-1" not in plugin.calls + + +def test_plugin_executor_exception_does_not_break_checkpointing(): + """Test that a plugin exception does not break the checkpoint processing loop.""" + mock_client = create_autospec(spec=LambdaClient) + + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + class _ExplodingPlugin(DurableExecutionPlugin): + def on_operation_start(self, info): + raise RuntimeError("plugin exploded") + + def on_operation_end(self, info): + raise RuntimeError("plugin exploded") + + exploding_plugin = _ExplodingPlugin() + plugin_executor = PluginExecutor(plugins=[exploding_plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + # Should not raise even though plugin explodes + state.create_checkpoint(operation_update, is_sync=True) + + # Checkpoint should still have been called successfully + assert mock_client.checkpoint.call_count == 1 + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_not_called_for_pending_operations(): + """Test that plugin_executor.on_operation_update fires on_user_function_end for PENDING operations.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return a PENDING step operation from checkpoint (simulates a retry scenario) + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails( + attempt=1, + result=None, + error=ErrorObject( + message="transient failure", + type="RetryableError", + data=None, + stack_trace=None, + ), + ), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # operation_end should NOT fire for PENDING (only for terminal statuses) + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +# endregion Plugin Executor Integration Tests