Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 50 additions & 37 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,49 +777,62 @@ async def _run_loop(
Yields:
Events from the event loop cycle.
"""
before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages)
)
messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages
current_messages: Messages | None = messages

agent_result: AgentResult | None = None
try:
yield InitEventLoopEvent()
while current_messages is not None:
before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=current_messages)
)
current_messages = (
before_invocation_event.messages if before_invocation_event.messages is not None else current_messages
)

await self._append_messages(*messages)
agent_result: AgentResult | None = None
try:
yield InitEventLoopEvent()

structured_output_context = StructuredOutputContext(
structured_output_model or self._default_structured_output_model,
structured_output_prompt=structured_output_prompt or self._structured_output_prompt,
)
await self._append_messages(*current_messages)

# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(invocation_state, structured_output_context)
async for event in events:
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
isinstance(event, ModelStreamChunkEvent)
and event.chunk
and event.chunk.get("redactContent")
and event.chunk["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = self._redact_user_content(
self.messages[-1]["content"], str(event.chunk["redactContent"]["redactUserContentMessage"])
)
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
yield event
structured_output_context = StructuredOutputContext(
structured_output_model or self._default_structured_output_model,
structured_output_prompt=structured_output_prompt or self._structured_output_prompt,
)

# Capture the result from the final event if available
if isinstance(event, EventLoopStopEvent):
agent_result = AgentResult(*event["stop"])
# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(invocation_state, structured_output_context)
async for event in events:
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
isinstance(event, ModelStreamChunkEvent)
and event.chunk
and event.chunk.get("redactContent")
and event.chunk["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = self._redact_user_content(
self.messages[-1]["content"],
str(event.chunk["redactContent"]["redactUserContentMessage"]),
)
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
yield event

# Capture the result from the final event if available
if isinstance(event, EventLoopStopEvent):
agent_result = AgentResult(*event["stop"])

finally:
self.conversation_manager.apply_management(self)
await self.hooks.invoke_callbacks_async(
AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result)
)
finally:
self.conversation_manager.apply_management(self)
after_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result)
)

# Convert resume input to messages for next iteration, or None to stop
if after_invocation_event.resume is not None:
logger.debug("resume=<True> | hook requested agent resume with new input")
current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume)
else:
current_messages = None

async def _execute_event_loop_cycle(
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
Expand Down
15 changes: 15 additions & 0 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
if TYPE_CHECKING:
from ..agent.agent_result import AgentResult

from ..types.agent import AgentInput
from ..types.content import Message, Messages
from ..types.interrupt import _Interruptible
from ..types.streaming import StopReason
Expand Down Expand Up @@ -78,17 +79,31 @@ class AfterInvocationEvent(HookEvent):
- Agent.stream_async
- Agent.structured_output

Resume:
When ``resume`` is set to a non-None value by a hook callback, the agent will
automatically re-invoke itself with the provided input. This enables hooks to
implement autonomous looping patterns where the agent continues processing
based on its previous result. The resume triggers a full new invocation cycle
including ``BeforeInvocationEvent``.

Attributes:
invocation_state: State and configuration passed through the agent invocation.
This can include shared context for multi-agent coordination, request tracking,
and dynamic configuration.
result: The result of the agent invocation, if available.
This will be None when invoked from structured_output methods, as those return typed output directly rather
than AgentResult.
resume: When set to a non-None agent input by a hook callback, the agent will
re-invoke itself with this input. The value can be any valid AgentInput
(str, content blocks, messages, etc.). Defaults to None (no resume).
"""

invocation_state: dict[str, Any] = field(default_factory=dict)
result: "AgentResult | None" = None
resume: AgentInput = None

def _can_write(self, name: str) -> bool:
return name == "resume"

@property
def should_reverse_callbacks(self) -> bool:
Expand Down
30 changes: 30 additions & 0 deletions tests/strands/agent/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,33 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes
"""Test that BeforeInvocationEvent.agent is not writable."""
with pytest.raises(AttributeError, match="Property agent is not writable"):
start_request_event_with_messages.agent = Mock()


def test_after_invocation_event_resume_defaults_to_none(agent):
"""Test that AfterInvocationEvent.resume defaults to None."""
event = AfterInvocationEvent(agent=agent, result=None)
assert event.resume is None


def test_after_invocation_event_resume_is_writable(agent):
"""Test that AfterInvocationEvent.resume can be set by hooks."""
event = AfterInvocationEvent(agent=agent, result=None)
event.resume = "continue with this input"
assert event.resume == "continue with this input"


def test_after_invocation_event_resume_accepts_various_input_types(agent):
"""Test that resume accepts all AgentInput types."""
event = AfterInvocationEvent(agent=agent, result=None)

# String input
event.resume = "hello"
assert event.resume == "hello"

# Content block list
event.resume = [{"text": "hello"}]
assert event.resume == [{"text": "hello"}]

# None to stop
event.resume = None
assert event.resume is None
117 changes: 117 additions & 0 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,120 @@ async def capture_messages_hook(event: BeforeInvocationEvent):

# structured_output_async uses deprecated path that doesn't pass messages
assert received_messages is None


def test_after_invocation_resume_triggers_new_invocation():
"""Test that setting resume on AfterInvocationEvent re-invokes the agent."""
mock_provider = MockedModelProvider(
[
{"role": "assistant", "content": [{"text": "First response"}]},
{"role": "assistant", "content": [{"text": "Second response"}]},
]
)

resume_count = 0

async def resume_once(event: AfterInvocationEvent):
nonlocal resume_count
if resume_count == 0:
resume_count += 1
event.resume = "continue"

agent = Agent(model=mock_provider)
agent.hooks.add_callback(AfterInvocationEvent, resume_once)

result = agent("start")

# Agent should have been invoked twice
assert resume_count == 1
assert result.message["content"][0]["text"] == "Second response"
# 4 messages: user1, assistant1, user2 (resume), assistant2
assert len(agent.messages) == 4
assert agent.messages[0]["content"][0]["text"] == "start"
assert agent.messages[2]["content"][0]["text"] == "continue"


def test_after_invocation_resume_none_does_not_loop():
"""Test that resume=None (default) does not re-invoke the agent."""
mock_provider = MockedModelProvider(
[
{"role": "assistant", "content": [{"text": "Only response"}]},
]
)

call_count = 0

async def no_resume(event: AfterInvocationEvent):
nonlocal call_count
call_count += 1
# Don't set resume - should remain None

agent = Agent(model=mock_provider)
agent.hooks.add_callback(AfterInvocationEvent, no_resume)

result = agent("hello")

assert call_count == 1
assert result.message["content"][0]["text"] == "Only response"


def test_after_invocation_resume_fires_before_invocation_event():
"""Test that resume triggers BeforeInvocationEvent on each iteration."""
mock_provider = MockedModelProvider(
[
{"role": "assistant", "content": [{"text": "First"}]},
{"role": "assistant", "content": [{"text": "Second"}]},
]
)

before_invocation_count = 0
after_invocation_count = 0

async def count_before(event: BeforeInvocationEvent):
nonlocal before_invocation_count
before_invocation_count += 1

async def resume_once(event: AfterInvocationEvent):
nonlocal after_invocation_count
after_invocation_count += 1
if after_invocation_count == 1:
event.resume = "next"

agent = Agent(model=mock_provider)
agent.hooks.add_callback(BeforeInvocationEvent, count_before)
agent.hooks.add_callback(AfterInvocationEvent, resume_once)

agent("start")

# BeforeInvocationEvent should fire for both the initial and resumed invocation
assert before_invocation_count == 2
assert after_invocation_count == 2


def test_after_invocation_resume_multiple_times():
"""Test that resume can chain multiple re-invocations."""
mock_provider = MockedModelProvider(
[
{"role": "assistant", "content": [{"text": "Response 1"}]},
{"role": "assistant", "content": [{"text": "Response 2"}]},
{"role": "assistant", "content": [{"text": "Response 3"}]},
]
)

resume_count = 0

async def resume_twice(event: AfterInvocationEvent):
nonlocal resume_count
if resume_count < 2:
resume_count += 1
event.resume = f"iteration {resume_count + 1}"

agent = Agent(model=mock_provider)
agent.hooks.add_callback(AfterInvocationEvent, resume_twice)

result = agent("iteration 1")

assert resume_count == 2
assert result.message["content"][0]["text"] == "Response 3"
# 6 messages: 3 user + 3 assistant
assert len(agent.messages) == 6
Loading