diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ebead3b7d..c7a2ec7e7 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -777,49 +777,62 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( - BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages) - ) - messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages + current_messages: Messages | None = messages - agent_result: AgentResult | None = None - try: - yield InitEventLoopEvent() + while current_messages is not None: + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=current_messages) + ) + current_messages = ( + before_invocation_event.messages if before_invocation_event.messages is not None else current_messages + ) - await self._append_messages(*messages) + agent_result: AgentResult | None = None + try: + yield InitEventLoopEvent() - structured_output_context = StructuredOutputContext( - structured_output_model or self._default_structured_output_model, - structured_output_prompt=structured_output_prompt or self._structured_output_prompt, - ) + await self._append_messages(*current_messages) - # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state, structured_output_context) - async for event in events: - # Signal from the model provider that the message sent by the user should be redacted, - # likely due to a guardrail. - if ( - isinstance(event, ModelStreamChunkEvent) - and event.chunk - and event.chunk.get("redactContent") - and event.chunk["redactContent"].get("redactUserContentMessage") - ): - self.messages[-1]["content"] = self._redact_user_content( - self.messages[-1]["content"], str(event.chunk["redactContent"]["redactUserContentMessage"]) - ) - if self._session_manager: - self._session_manager.redact_latest_message(self.messages[-1], self) - yield event + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model, + structured_output_prompt=structured_output_prompt or self._structured_output_prompt, + ) - # Capture the result from the final event if available - if isinstance(event, EventLoopStopEvent): - agent_result = AgentResult(*event["stop"]) + # Execute the event loop cycle with retry logic for context limits + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) + async for event in events: + # Signal from the model provider that the message sent by the user should be redacted, + # likely due to a guardrail. + if ( + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") + ): + self.messages[-1]["content"] = self._redact_user_content( + self.messages[-1]["content"], + str(event.chunk["redactContent"]["redactUserContentMessage"]), + ) + if self._session_manager: + self._session_manager.redact_latest_message(self.messages[-1], self) + yield event + + # Capture the result from the final event if available + if isinstance(event, EventLoopStopEvent): + agent_result = AgentResult(*event["stop"]) - finally: - self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async( - AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) - ) + finally: + self.conversation_manager.apply_management(self) + after_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) + + # Convert resume input to messages for next iteration, or None to stop + if after_invocation_event.resume is not None: + logger.debug("resume= | hook requested agent resume with new input") + current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume) + else: + current_messages = None async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8d3e5d280..9186e0e70 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult +from ..types.agent import AgentInput from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason @@ -78,6 +79,13 @@ class AfterInvocationEvent(HookEvent): - Agent.stream_async - Agent.structured_output + Resume: + When ``resume`` is set to a non-None value by a hook callback, the agent will + automatically re-invoke itself with the provided input. This enables hooks to + implement autonomous looping patterns where the agent continues processing + based on its previous result. The resume triggers a full new invocation cycle + including ``BeforeInvocationEvent``. + Attributes: invocation_state: State and configuration passed through the agent invocation. This can include shared context for multi-agent coordination, request tracking, @@ -85,10 +93,17 @@ class AfterInvocationEvent(HookEvent): result: The result of the agent invocation, if available. This will be None when invoked from structured_output methods, as those return typed output directly rather than AgentResult. + resume: When set to a non-None agent input by a hook callback, the agent will + re-invoke itself with this input. The value can be any valid AgentInput + (str, content blocks, messages, etc.). Defaults to None (no resume). """ invocation_state: dict[str, Any] = field(default_factory=dict) result: "AgentResult | None" = None + resume: AgentInput = None + + def _can_write(self, name: str) -> bool: + return name == "resume" @property def should_reverse_callbacks(self) -> bool: diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index de551d137..0e03fbbcd 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -230,3 +230,33 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes """Test that BeforeInvocationEvent.agent is not writable.""" with pytest.raises(AttributeError, match="Property agent is not writable"): start_request_event_with_messages.agent = Mock() + + +def test_after_invocation_event_resume_defaults_to_none(agent): + """Test that AfterInvocationEvent.resume defaults to None.""" + event = AfterInvocationEvent(agent=agent, result=None) + assert event.resume is None + + +def test_after_invocation_event_resume_is_writable(agent): + """Test that AfterInvocationEvent.resume can be set by hooks.""" + event = AfterInvocationEvent(agent=agent, result=None) + event.resume = "continue with this input" + assert event.resume == "continue with this input" + + +def test_after_invocation_event_resume_accepts_various_input_types(agent): + """Test that resume accepts all AgentInput types.""" + event = AfterInvocationEvent(agent=agent, result=None) + + # String input + event.resume = "hello" + assert event.resume == "hello" + + # Content block list + event.resume = [{"text": "hello"}] + assert event.resume == [{"text": "hello"}] + + # None to stop + event.resume = None + assert event.resume is None diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 4397b9628..4ce971b03 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -694,3 +694,120 @@ async def capture_messages_hook(event: BeforeInvocationEvent): # structured_output_async uses deprecated path that doesn't pass messages assert received_messages is None + + +def test_after_invocation_resume_triggers_new_invocation(): + """Test that setting resume on AfterInvocationEvent re-invokes the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + ) + + resume_count = 0 + + async def resume_once(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count == 0: + resume_count += 1 + event.resume = "continue" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + result = agent("start") + + # Agent should have been invoked twice + assert resume_count == 1 + assert result.message["content"][0]["text"] == "Second response" + # 4 messages: user1, assistant1, user2 (resume), assistant2 + assert len(agent.messages) == 4 + assert agent.messages[0]["content"][0]["text"] == "start" + assert agent.messages[2]["content"][0]["text"] == "continue" + + +def test_after_invocation_resume_none_does_not_loop(): + """Test that resume=None (default) does not re-invoke the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Only response"}]}, + ] + ) + + call_count = 0 + + async def no_resume(event: AfterInvocationEvent): + nonlocal call_count + call_count += 1 + # Don't set resume - should remain None + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, no_resume) + + result = agent("hello") + + assert call_count == 1 + assert result.message["content"][0]["text"] == "Only response" + + +def test_after_invocation_resume_fires_before_invocation_event(): + """Test that resume triggers BeforeInvocationEvent on each iteration.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + ] + ) + + before_invocation_count = 0 + after_invocation_count = 0 + + async def count_before(event: BeforeInvocationEvent): + nonlocal before_invocation_count + before_invocation_count += 1 + + async def resume_once(event: AfterInvocationEvent): + nonlocal after_invocation_count + after_invocation_count += 1 + if after_invocation_count == 1: + event.resume = "next" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, count_before) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + agent("start") + + # BeforeInvocationEvent should fire for both the initial and resumed invocation + assert before_invocation_count == 2 + assert after_invocation_count == 2 + + +def test_after_invocation_resume_multiple_times(): + """Test that resume can chain multiple re-invocations.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + ) + + resume_count = 0 + + async def resume_twice(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count < 2: + resume_count += 1 + event.resume = f"iteration {resume_count + 1}" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_twice) + + result = agent("iteration 1") + + assert resume_count == 2 + assert result.message["content"][0]["text"] == "Response 3" + # 6 messages: 3 user + 3 assistant + assert len(agent.messages) == 6