Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ class AgentCoreMemoryConfig(BaseModel):
persistence_mode: Controls what gets persisted to AgentCore Memory.
FULL (default): persist everything. NONE: disable all persistence while keeping
local state management and memory injection working.
async_mode: When True, the session manager registers async hook callbacks that
offload the per-turn boto3 calls (append_message, sync_agent,
retrieve_customer_context, and buffer flushes) to a thread via
asyncio.to_thread, keeping the asyncio event loop unblocked. Intended for
async agent runtimes (e.g. Agent.stream_async() in a WebSocket server).
Default is False (existing synchronous behavior, unchanged).

Requires async invocation (stream_async / invoke_async). Sync agent() calls
will raise RuntimeError from Strands' hook registry because it refuses to
dispatch coroutine callbacks through the sync path.

Note: this does NOT cover agent initialization. Strands disallows async
callbacks for AgentInitializedEvent, so the read_session / read_agent /
list_messages calls that run during Agent(...) construction still block
the calling thread. If that matters, construct the Agent off-loop
(e.g. `await asyncio.to_thread(Agent, ...)`).
"""

memory_id: str = Field(min_length=1)
Expand All @@ -81,6 +97,7 @@ class AgentCoreMemoryConfig(BaseModel):
default_metadata: Optional[Dict[str, Any]] = None
metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None
persistence_mode: PersistenceMode = Field(default=PersistenceMode.FULL)
async_mode: bool = Field(default=False)

@field_validator("default_metadata", mode="before")
@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""AgentCore Memory-based session manager for Bedrock AgentCore Memory integration."""

import asyncio
import json
import logging
import threading
Expand All @@ -10,7 +11,13 @@

import boto3
from botocore.config import Config as BotocoreConfig
from strands.experimental.hooks.multiagent.events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
MultiAgentInitializedEvent,
)
from strands.hooks import AfterInvocationEvent, MessageAddedEvent
from strands.hooks.events import AgentInitializedEvent
from strands.hooks.registry import HookRegistry
from strands.session.repository_session_manager import RepositorySessionManager
from strands.session.session_repository import SessionRepository
Expand Down Expand Up @@ -906,16 +913,79 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig):
def register_hooks(self, registry: HookRegistry, **kwargs) -> None:
"""Register additional hooks.

In sync mode (the default), delegates to the base class and adds the
retrieve_customer_context + batching callbacks synchronously, preserving
existing behavior exactly.

In async mode, registers async callbacks that wrap every per-turn
boto3-backed operation (append_message, sync_agent, buffer flushes,
customer-context retrieval) with asyncio.to_thread, so the asyncio
event loop stays free while boto3 is blocking on the network.

Note: AgentInitializedEvent cannot be async per Strands' HookRegistry,
so agent restoration (read_session / read_agent / list_messages) still
blocks the calling thread in async mode — see AgentCoreMemoryConfig
docstring for mitigations.

Args:
registry (HookRegistry): The hook registry to register callbacks with.
**kwargs: Additional keyword arguments.
"""
RepositorySessionManager.register_hooks(self, registry, **kwargs)
registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event))
if not self.config.async_mode:
RepositorySessionManager.register_hooks(self, registry, **kwargs)
registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event))

# Only register AfterInvocationEvent hook when batching is enabled
if self.config.batch_size > 1:
registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages())
return

# Async mode: register async callbacks that offload the existing sync
# methods to a worker thread via asyncio.to_thread. AgentInitializedEvent
# must stay sync (Strands disallows async callbacks on this event; see
# strands/hooks/registry.py:174).
logger.warning(
"AgentCoreMemorySessionManager async_mode=True: the agent must be invoked "
"via the async path (e.g. agent.stream_async(...) or agent.invoke_async(...)). "
"Sync invocation will raise RuntimeError from Strands' hook registry."
)

registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))

async def _on_message_added_persist(event: MessageAddedEvent) -> None:
await asyncio.to_thread(self.append_message, event.message, event.agent)
await asyncio.to_thread(self.sync_agent, event.agent)

async def _on_message_added_retrieve(event: MessageAddedEvent) -> None:
await asyncio.to_thread(self.retrieve_customer_context, event)

async def _on_after_invocation_sync(event: AfterInvocationEvent) -> None:
await asyncio.to_thread(self.sync_agent, event.agent)

registry.add_callback(MessageAddedEvent, _on_message_added_persist)
registry.add_callback(AfterInvocationEvent, _on_after_invocation_sync)
registry.add_callback(MessageAddedEvent, _on_message_added_retrieve)

# Only register AfterInvocationEvent hook when batching is enabled
if self.config.batch_size > 1:
registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages())

async def _on_after_invocation_flush(event: AfterInvocationEvent) -> None:
await asyncio.to_thread(self._flush_messages)

registry.add_callback(AfterInvocationEvent, _on_after_invocation_flush)

# Register multi-agent callbacks so async-mode parity matches sync-mode
async def _on_multi_agent_initialized(event: MultiAgentInitializedEvent) -> None:
await asyncio.to_thread(self.initialize_multi_agent, event.source)

async def _on_after_node_call(event: AfterNodeCallEvent) -> None:
await asyncio.to_thread(self.sync_multi_agent, event.source)

async def _on_after_multi_agent_invocation(event: AfterMultiAgentInvocationEvent) -> None:
await asyncio.to_thread(self.sync_multi_agent, event.source)

registry.add_callback(MultiAgentInitializedEvent, _on_multi_agent_initialized)
registry.add_callback(AfterNodeCallEvent, _on_after_node_call)
registry.add_callback(AfterMultiAgentInvocationEvent, _on_after_multi_agent_invocation)

@override
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for AgentCoreMemorySessionManager."""

import asyncio
import inspect
import logging
import time
from datetime import datetime, timezone
Expand All @@ -9,6 +11,11 @@
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError
from strands.agent.agent import Agent
from strands.experimental.hooks.multiagent.events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
MultiAgentInitializedEvent,
)
from strands.hooks import AfterInvocationEvent, MessageAddedEvent
from strands.hooks.registry import HookRegistry
from strands.types.exceptions import SessionException
Expand Down Expand Up @@ -3580,3 +3587,123 @@ def test_retrieve_customer_context_works(self, mock_memory_client):

mock_memory_client.retrieve_memories.assert_called_once()
assert "<user_context>" in mock_agent.messages[0]["content"][0]["text"]


class TestAsyncMode:
"""Tests for async_mode: callbacks must not block the event loop."""

def test_async_mode_defaults_to_false(self, agentcore_config):
assert agentcore_config.async_mode is False

def test_sync_mode_registers_sync_callbacks(self, mock_memory_client):
"""async_mode=False: all MessageAddedEvent/AfterInvocationEvent callbacks are sync."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=False)
manager = _create_session_manager(config, mock_memory_client)
registry = HookRegistry()
manager.register_hooks(registry)

for event_type in (MessageAddedEvent, AfterInvocationEvent):
for cb in registry.get_callbacks_for(
event_type(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]})
if event_type is MessageAddedEvent
else event_type(agent=Mock())
):
assert not inspect.iscoroutinefunction(cb), f"Sync mode leaked an async callback for {event_type}"

def test_async_mode_registers_async_callbacks(self, mock_memory_client):
"""async_mode=True: MessageAddedEvent and AfterInvocationEvent callbacks are coroutine functions."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True)
manager = _create_session_manager(config, mock_memory_client)
registry = HookRegistry()
manager.register_hooks(registry)

msg_callbacks = registry.get_callbacks_for(
MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]})
)
assert msg_callbacks, "No MessageAddedEvent callbacks registered in async mode"
assert all(inspect.iscoroutinefunction(cb) for cb in msg_callbacks)

after_callbacks = registry.get_callbacks_for(AfterInvocationEvent(agent=Mock()))
assert after_callbacks, "No AfterInvocationEvent callbacks registered in async mode"
assert all(inspect.iscoroutinefunction(cb) for cb in after_callbacks)

async def test_async_mode_does_not_block_event_loop(self, mock_memory_client):
"""The async hooks run boto3 on a worker thread, so the event loop can make progress concurrently."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True)
manager = _create_session_manager(config, mock_memory_client)

# Simulate each sync session-manager method blocking on boto3.
def slow_append_message(message, agent, **kwargs):
time.sleep(0.2)

def slow_sync_agent(agent, **kwargs):
time.sleep(0.2)

manager.append_message = slow_append_message
manager.sync_agent = slow_sync_agent

registry = HookRegistry()
manager.register_hooks(registry)

persist_callbacks = [
cb
for cb in registry.get_callbacks_for(
MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]})
)
if asyncio.iscoroutinefunction(cb)
]
assert persist_callbacks

event = MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "hello"}]})

# Ticker proves the event loop made progress while the hook awaited to_thread.
ticks = 0

async def ticker():
nonlocal ticks
while True:
await asyncio.sleep(0.01)
ticks += 1

ticker_task = asyncio.create_task(ticker())
try:
# Run the persist callback (append_message + sync_agent); both sleep 0.2s on a worker thread.
await persist_callbacks[0](event)
finally:
ticker_task.cancel()

assert ticks > 5, f"Event loop was blocked; only {ticks} ticks recorded"

async def test_async_mode_batching_registers_flush_callback(self, mock_memory_client):
"""async_mode=True with batch_size>1: AfterInvocationEvent gets both sync_agent and flush callbacks."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True)
manager = _create_session_manager(config, mock_memory_client)
registry = HookRegistry()
manager.register_hooks(registry)

after_callbacks = list(registry.get_callbacks_for(AfterInvocationEvent(agent=Mock())))
assert len(after_callbacks) == 2
assert all(asyncio.iscoroutinefunction(cb) for cb in after_callbacks)

def test_async_mode_registers_multi_agent_callbacks(self, mock_memory_client):
"""async_mode=True: multi-agent events get async callbacks (parity with sync mode)."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True)
manager = _create_session_manager(config, mock_memory_client)
registry = HookRegistry()
manager.register_hooks(registry)

for event_type in (MultiAgentInitializedEvent, AfterNodeCallEvent, AfterMultiAgentInvocationEvent):
callbacks = registry._registered_callbacks.get(event_type, [])
assert callbacks, f"No callbacks registered for {event_type.__name__}"
assert all(asyncio.iscoroutinefunction(cb) for cb in callbacks)

def test_async_mode_logs_sync_invocation_warning(self, mock_memory_client, caplog):
"""async_mode=True emits a WARNING at register_hooks time pointing users to stream_async/invoke_async."""
config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True)
manager = _create_session_manager(config, mock_memory_client)
registry = HookRegistry()

with caplog.at_level(logging.WARNING, logger="bedrock_agentcore.memory.integrations.strands.session_manager"):
manager.register_hooks(registry)

assert any("async_mode=True" in rec.message and "stream_async" in rec.message for rec in caplog.records)
Loading