diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 3fc8901b..a42e824e 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -68,6 +68,22 @@ class AgentCoreMemoryConfig(BaseModel): persistence_mode: Controls what gets persisted to AgentCore Memory. FULL (default): persist everything. NONE: disable all persistence while keeping local state management and memory injection working. + async_mode: When True, the session manager registers async hook callbacks that + offload the per-turn boto3 calls (append_message, sync_agent, + retrieve_customer_context, and buffer flushes) to a thread via + asyncio.to_thread, keeping the asyncio event loop unblocked. Intended for + async agent runtimes (e.g. Agent.stream_async() in a WebSocket server). + Default is False (existing synchronous behavior, unchanged). + + Requires async invocation (stream_async / invoke_async). Sync agent() calls + will raise RuntimeError from Strands' hook registry because it refuses to + dispatch coroutine callbacks through the sync path. + + Note: this does NOT cover agent initialization. Strands disallows async + callbacks for AgentInitializedEvent, so the read_session / read_agent / + list_messages calls that run during Agent(...) construction still block + the calling thread. If that matters, construct the Agent off-loop + (e.g. `await asyncio.to_thread(Agent, ...)`). 
""" memory_id: str = Field(min_length=1) @@ -81,6 +97,7 @@ class AgentCoreMemoryConfig(BaseModel): default_metadata: Optional[Dict[str, Any]] = None metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None persistence_mode: PersistenceMode = Field(default=PersistenceMode.FULL) + async_mode: bool = Field(default=False) @field_validator("default_metadata", mode="before") @classmethod diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 6bfaca9d..6e72df53 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -1,5 +1,6 @@ """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.""" +import asyncio import json import logging import threading @@ -10,7 +11,13 @@ import boto3 from botocore.config import Config as BotocoreConfig +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from strands.hooks import AfterInvocationEvent, MessageAddedEvent +from strands.hooks.events import AgentInitializedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -906,16 +913,79 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): def register_hooks(self, registry: HookRegistry, **kwargs) -> None: """Register additional hooks. + In sync mode (the default), delegates to the base class and adds the + retrieve_customer_context + batching callbacks synchronously, preserving + existing behavior exactly. 
+ + In async mode, registers async callbacks that wrap every per-turn + boto3-backed operation (append_message, sync_agent, buffer flushes, + customer-context retrieval) with asyncio.to_thread, so the asyncio + event loop stays free while boto3 is blocking on the network. + + Note: AgentInitializedEvent cannot be async per Strands' HookRegistry, + so agent restoration (read_session / read_agent / list_messages) still + blocks the calling thread in async mode — see AgentCoreMemoryConfig + docstring for mitigations. + Args: registry (HookRegistry): The hook registry to register callbacks with. **kwargs: Additional keyword arguments. """ - RepositorySessionManager.register_hooks(self, registry, **kwargs) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + if not self.config.async_mode: + RepositorySessionManager.register_hooks(self, registry, **kwargs) + registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + + # Only register AfterInvocationEvent hook when batching is enabled + if self.config.batch_size > 1: + registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + return + + # Async mode: register async callbacks that offload the existing sync + # methods to a worker thread via asyncio.to_thread. AgentInitializedEvent + # must stay sync (Strands disallows async callbacks on this event; see the + # AgentInitializedEvent check in strands/hooks/registry.py). + logger.warning( + "AgentCoreMemorySessionManager async_mode=True: the agent must be invoked " + "via the async path (e.g. agent.stream_async(...) or agent.invoke_async(...)). " + "Sync invocation will raise RuntimeError from Strands' hook registry." 
+ ) + + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + + async def _on_message_added_persist(event: MessageAddedEvent) -> None: + await asyncio.to_thread(self.append_message, event.message, event.agent) + await asyncio.to_thread(self.sync_agent, event.agent) + + async def _on_message_added_retrieve(event: MessageAddedEvent) -> None: + await asyncio.to_thread(self.retrieve_customer_context, event) + + async def _on_after_invocation_sync(event: AfterInvocationEvent) -> None: + await asyncio.to_thread(self.sync_agent, event.agent) + + registry.add_callback(MessageAddedEvent, _on_message_added_persist) + registry.add_callback(AfterInvocationEvent, _on_after_invocation_sync) + registry.add_callback(MessageAddedEvent, _on_message_added_retrieve) - # Only register AfterInvocationEvent hook when batching is enabled if self.config.batch_size > 1: - registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + + async def _on_after_invocation_flush(event: AfterInvocationEvent) -> None: + await asyncio.to_thread(self._flush_messages) + + registry.add_callback(AfterInvocationEvent, _on_after_invocation_flush) + + # Register multi-agent callbacks so async-mode parity matches sync-mode + async def _on_multi_agent_initialized(event: MultiAgentInitializedEvent) -> None: + await asyncio.to_thread(self.initialize_multi_agent, event.source) + + async def _on_after_node_call(event: AfterNodeCallEvent) -> None: + await asyncio.to_thread(self.sync_multi_agent, event.source) + + async def _on_after_multi_agent_invocation(event: AfterMultiAgentInvocationEvent) -> None: + await asyncio.to_thread(self.sync_multi_agent, event.source) + + registry.add_callback(MultiAgentInitializedEvent, _on_multi_agent_initialized) + registry.add_callback(AfterNodeCallEvent, _on_after_node_call) + registry.add_callback(AfterMultiAgentInvocationEvent, _on_after_multi_agent_invocation) @override def initialize(self, agent: "Agent", **kwargs: 
Any) -> None: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 2daca61b..8fa586f0 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1,5 +1,7 @@ """Tests for AgentCoreMemorySessionManager.""" +import asyncio +import inspect import logging import time from datetime import datetime, timezone @@ -9,6 +11,11 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError from strands.agent.agent import Agent +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from strands.hooks import AfterInvocationEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.types.exceptions import SessionException @@ -3580,3 +3587,123 @@ def test_retrieve_customer_context_works(self, mock_memory_client): mock_memory_client.retrieve_memories.assert_called_once() assert "" in mock_agent.messages[0]["content"][0]["text"] + + +class TestAsyncMode: + """Tests for async_mode: callbacks must not block the event loop.""" + + def test_async_mode_defaults_to_false(self, agentcore_config): + assert agentcore_config.async_mode is False + + def test_sync_mode_registers_sync_callbacks(self, mock_memory_client): + """async_mode=False: all MessageAddedEvent/AfterInvocationEvent callbacks are sync.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=False) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + for event_type in (MessageAddedEvent, AfterInvocationEvent): + for cb in 
registry.get_callbacks_for( + event_type(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + if event_type is MessageAddedEvent + else event_type(agent=Mock()) + ): + assert not inspect.iscoroutinefunction(cb), f"Sync mode leaked an async callback for {event_type}" + + def test_async_mode_registers_async_callbacks(self, mock_memory_client): + """async_mode=True: MessageAddedEvent and AfterInvocationEvent callbacks are coroutine functions.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + msg_callbacks = registry.get_callbacks_for( + MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + ) + assert msg_callbacks, "No MessageAddedEvent callbacks registered in async mode" + assert all(inspect.iscoroutinefunction(cb) for cb in msg_callbacks) + + after_callbacks = registry.get_callbacks_for(AfterInvocationEvent(agent=Mock())) + assert after_callbacks, "No AfterInvocationEvent callbacks registered in async mode" + assert all(inspect.iscoroutinefunction(cb) for cb in after_callbacks) + + async def test_async_mode_does_not_block_event_loop(self, mock_memory_client): + """The async hooks run boto3 on a worker thread, so the event loop can make progress concurrently.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + + # Simulate each sync session-manager method blocking on boto3. 
+ def slow_append_message(message, agent, **kwargs): + time.sleep(0.2) + + def slow_sync_agent(agent, **kwargs): + time.sleep(0.2) + + manager.append_message = slow_append_message + manager.sync_agent = slow_sync_agent + + registry = HookRegistry() + manager.register_hooks(registry) + + persist_callbacks = [ + cb + for cb in registry.get_callbacks_for( + MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + ) + if asyncio.iscoroutinefunction(cb) + ] + assert persist_callbacks + + event = MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "hello"}]}) + + # Ticker proves the event loop made progress while the hook awaited to_thread. + ticks = 0 + + async def ticker(): + nonlocal ticks + while True: + await asyncio.sleep(0.01) + ticks += 1 + + ticker_task = asyncio.create_task(ticker()) + try: + # Run the persist callback (append_message + sync_agent); both sleep 0.2s on a worker thread. + await persist_callbacks[0](event) + finally: + ticker_task.cancel() + + assert ticks > 5, f"Event loop was blocked; only {ticks} ticks recorded" + + async def test_async_mode_batching_registers_flush_callback(self, mock_memory_client): + """async_mode=True with batch_size>1: AfterInvocationEvent gets both sync_agent and flush callbacks.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + after_callbacks = list(registry.get_callbacks_for(AfterInvocationEvent(agent=Mock()))) + assert len(after_callbacks) == 2 + assert all(asyncio.iscoroutinefunction(cb) for cb in after_callbacks) + + def test_async_mode_registers_multi_agent_callbacks(self, mock_memory_client): + """async_mode=True: multi-agent events get async callbacks (parity with sync mode).""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", 
async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + for event_type in (MultiAgentInitializedEvent, AfterNodeCallEvent, AfterMultiAgentInvocationEvent): + callbacks = registry._registered_callbacks.get(event_type, []) + assert callbacks, f"No callbacks registered for {event_type.__name__}" + assert all(asyncio.iscoroutinefunction(cb) for cb in callbacks) + + def test_async_mode_logs_sync_invocation_warning(self, mock_memory_client, caplog): + """async_mode=True emits a WARNING at register_hooks time pointing users to stream_async/invoke_async.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + + with caplog.at_level(logging.WARNING, logger="bedrock_agentcore.memory.integrations.strands.session_manager"): + manager.register_hooks(registry) + + assert any("async_mode=True" in rec.message and "stream_async" in rec.message for rec in caplog.records)