From e40696f3625e5e7db593f49edc4f5f66e72a2c4e Mon Sep 17 00:00:00 2001 From: Jay Goyani Date: Thu, 26 Feb 2026 21:24:39 +0000 Subject: [PATCH] feat: add CancellationToken for graceful agent execution cancellation Implement CancellationToken to enable graceful cancellation of agent execution from external contexts (e.g., web requests, background threads). Key features: - Thread-safe CancellationToken class with cancel() and is_cancelled() methods - Agent accepts optional cancellation_token parameter at creation time - Four cancellation checkpoints in event loop: 1. Start of event loop cycle 2. Before model execution 3. During model response streaming 4. Before tool execution - New 'cancelled' stop reason in StopReason type - Token automatically added to invocation_state for event loop access Implementation details: - CancellationToken uses threading.Lock for thread safety - Zero usage/metrics returned when cancelled (no model execution occurred) - Cancellation detected at checkpoints returns proper empty message structure - Token shared by reference across packages in same process Tests included: - Unit tests for CancellationToken thread safety and behavior - Unit tests for agent cancellation at different checkpoints - Integration tests with real model providers (require credentials) --- src/strands/__init__.py | 2 + src/strands/agent/agent.py | 15 + src/strands/event_loop/event_loop.py | 67 ++++ src/strands/event_loop/streaming.py | 17 +- src/strands/plugins/plugin.py | 2 +- src/strands/types/__init__.py | 3 +- src/strands/types/cancellation.py | 55 +++ src/strands/types/event_loop.py | 2 + .../strands/agent/test_agent_cancellation.py | 283 +++++++++++++ tests/strands/types/test_cancellation.py | 143 +++++++ tests_integ/test_cancellation.py | 377 ++++++++++++++++++ 11 files changed, 962 insertions(+), 4 deletions(-) create mode 100644 src/strands/types/cancellation.py create mode 100644 tests/strands/agent/test_agent_cancellation.py create mode 100644 
tests/strands/types/test_cancellation.py create mode 100644 tests_integ/test_cancellation.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..b43442c4f 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -6,12 +6,14 @@ from .event_loop._retry import ModelRetryStrategy from .plugins import Plugin from .tools.decorator import tool +from .types.cancellation import CancellationToken from .types.tools import ToolContext __all__ = [ "Agent", "AgentBase", "agent", + "CancellationToken", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ebead3b7d..4dc50a491 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -29,6 +29,7 @@ from ..event_loop._retry import ModelRetryStrategy from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content +from ..types.cancellation import CancellationToken if TYPE_CHECKING: from ..tools import ToolProvider @@ -135,6 +136,7 @@ def __init__( tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, + cancellation_token: CancellationToken | None = None, ): """Initialize the Agent with the specified configuration. @@ -201,6 +203,12 @@ def __init__( Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided only for advanced use cases where the caller understands the risks. + cancellation_token: Optional token for cancelling agent execution. 
+ When provided, the agent will check this token at strategic checkpoints during execution + (before model calls, during streaming, before tool execution) and stop gracefully if + cancellation is requested. The token can be cancelled from external contexts (other threads, + web requests, etc.) by calling token.cancel(). + Defaults to None (no cancellation support). Raises: ValueError: If agent id contains path separators. @@ -240,6 +248,9 @@ def __init__( self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory + # Store cancellation token for graceful termination + self.cancellation_token = cancellation_token + self.tool_registry = ToolRegistry() # Process tool list if provided @@ -724,6 +735,10 @@ async def stream_async( if invocation_state is not None: merged_state = invocation_state + # Add cancellation token to invocation state if provided + if self.cancellation_token is not None: + merged_state["cancellation_token"] = self.cancellation_token + callback_handler = self.callback_handler if kwargs: callback_handler = kwargs.get("callback_handler", self.callback_handler) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3113ddb79..00a306492 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -33,6 +33,7 @@ TypedEvent, ) from ..types.content import Message, Messages +from ..types.event_loop import Metrics, Usage from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -55,6 +56,23 @@ MAX_DELAY = 240 # 4 minutes +def _should_cancel(invocation_state: dict[str, Any]) -> bool: + """Check if cancellation has been requested. + + This helper function checks the cancellation token in the invocation state + and returns True if cancellation has been requested. It's called at strategic + checkpoints throughout the event loop to enable graceful termination. 
+ + Args: + invocation_state: Invocation state containing optional cancellation token. + + Returns: + True if cancellation has been requested, False otherwise. + """ + token = invocation_state.get("cancellation_token") + return token.is_cancelled() if token else False + + def _has_tool_use_in_latest_message(messages: "Messages") -> bool: """Check if the latest message contains any ToolUse content blocks. @@ -129,6 +147,22 @@ async def event_loop_cycle( yield StartEvent() yield StartEventLoopEvent() + # CHECKPOINT 1: Check for cancellation at start of event loop cycle + # This allows cancellation before any model or tool execution begins + if _should_cancel(invocation_state): + logger.debug( + "event_loop_cycle_id=<%s> | cancellation detected at cycle start", + invocation_state.get("event_loop_cycle_id"), + ) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) + yield EventLoopStopEvent( + "cancelled", + {"role": "assistant", "content": []}, + agent.event_loop_metrics, + invocation_state["request_state"], + ) + return + # Create tracer span for this event loop cycle tracer = get_tracer() cycle_span = tracer.start_event_loop_cycle_span( @@ -307,6 +341,23 @@ async def _handle_model_execution( # Retry loop - actual retry logic is handled by retry_strategy hook # Hooks control when to stop retrying via the event.retry flag while True: + # CHECKPOINT 2: Check for cancellation before model call + # This prevents unnecessary model invocations when cancellation is requested + if _should_cancel(invocation_state): + logger.debug( + "model_id=<%s> | cancellation detected before model call", + agent.model.config.get("model_id") if hasattr(agent.model, "config") else None, + ) + # Return cancelled stop reason with empty message and zero usage/metrics + # since no model execution occurred + yield ModelStopReason( + stop_reason="cancelled", + message={"role": "assistant", "content": []}, + usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + 
metrics=Metrics(latencyMs=0), + ) + return + model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, @@ -465,6 +516,22 @@ async def _handle_tool_execution( tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] interrupts = [] + + # CHECKPOINT 4: Check for cancellation before tool execution + # This prevents tool execution when cancellation is requested + if _should_cancel(invocation_state): + logger.debug("tool_count=<%d> | cancellation detected before tool execution", len(tool_uses)) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "cancelled", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + ) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + return + tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index b157f740e..ebfeca7b5 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -22,6 +22,7 @@ ToolUseStreamEvent, TypedEvent, ) +from ..types.cancellation import CancellationToken from ..types.citations import CitationsContentBlock from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.streaming import ( @@ -368,13 +369,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non async def process_stream( - chunks: AsyncIterable[StreamEvent], start_time: float | None = None + chunks: AsyncIterable[StreamEvent], + start_time: float | None = None, + cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and 
extracting usage metrics. Args: chunks: The chunks of the response stream from the model. start_time: Time when the model request is initiated + cancellation_token: Optional token to check for cancellation during streaming. Yields: The reason for stopping, the constructed message, and the usage metrics. @@ -395,6 +399,14 @@ async def process_stream( metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # CHECKPOINT 3: Check for cancellation before processing stream chunks + # This allows cancellation during model response streaming + if cancellation_token and cancellation_token.is_cancelled(): + logger.debug("cancellation detected during stream processing") + # Return cancelled stop reason with current state + yield ModelStopReason(stop_reason="cancelled", message=state["message"], usage=usage, metrics=metrics) + return + # Track first byte time when we get first content if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): first_byte_time = time.time() @@ -463,5 +475,6 @@ async def stream_messages( invocation_state=invocation_state, ) - async for event in process_stream(chunks, start_time): + cancellation_token = invocation_state.get("cancellation_token") if invocation_state else None + async for event in process_stream(chunks, start_time, cancellation_token): yield event diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index e9f35f112..32486a3f1 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -16,7 +16,7 @@ class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They can register hooks, modify agent attributes, or perform other + They can register hooks, modify agent attributes, or perform other setup tasks on an agent instance. 
Attributes: diff --git a/src/strands/types/__init__.py b/src/strands/types/__init__.py index 7eef60cb4..db1ba73cd 100644 --- a/src/strands/types/__init__.py +++ b/src/strands/types/__init__.py @@ -1,5 +1,6 @@ """SDK type definitions.""" +from .cancellation import CancellationToken from .collections import PaginatedList -__all__ = ["PaginatedList"] +__all__ = ["CancellationToken", "PaginatedList"] diff --git a/src/strands/types/cancellation.py b/src/strands/types/cancellation.py new file mode 100644 index 000000000..ecdbebd8e --- /dev/null +++ b/src/strands/types/cancellation.py @@ -0,0 +1,55 @@ +"""Cancellation token types for graceful agent termination.""" + +import threading + + +class CancellationToken: + """Thread-safe cancellation token for graceful agent termination. + + This token can be used to signal cancellation requests from any thread + and checked synchronously during agent execution. When cancelled, the + agent will stop processing and yield a stop event with interrupt reasoning. + + Example: + ```python + token = CancellationToken() + + # In another thread or external system + token.cancel() + + # In agent execution + if token.is_cancelled(): + # Stop processing + pass + ``` + + Note: + This is a minimal implementation focused on cancellation signaling. + Callback registration for resource cleanup can be added in a future + phase if resource cleanup use cases emerge. + """ + + def __init__(self) -> None: + """Initialize a new cancellation token.""" + self._cancelled = False + self._lock = threading.Lock() + + def cancel(self) -> None: + """Signal cancellation request. + + This method is thread-safe and can be called from any thread. + Multiple calls to cancel() are safe and idempotent. + """ + with self._lock: + self._cancelled = True + + def is_cancelled(self) -> bool: + """Check if cancellation has been requested. + + This method is thread-safe and can be called from any thread. 
+ + Returns: + True if cancellation has been requested, False otherwise. + """ + with self._lock: + return self._cancelled diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2a7ad344e..87bd0782b 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -37,6 +37,7 @@ class Metrics(TypedDict, total=False): StopReason = Literal[ + "cancelled", "content_filtered", "end_turn", "guardrail_intervened", @@ -47,6 +48,7 @@ class Metrics(TypedDict, total=False): ] """Reason for the model ending its response generation. +- "cancelled": Agent execution was cancelled via CancellationToken - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py new file mode 100644 index 000000000..8a1bd3370 --- /dev/null +++ b/tests/strands/agent/test_agent_cancellation.py @@ -0,0 +1,283 @@ +"""Tests for agent cancellation functionality.""" + +import asyncio +import time + +import pytest + +from strands import Agent, CancellationToken +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Default agent response for simple tests +DEFAULT_RESPONSE = { + "role": "assistant", + "content": [{"text": "Hello! How can I help you?"}], +} + + +@pytest.mark.asyncio +async def test_agent_cancellation_before_model_call(): + """Test cancellation before model call starts. + + This test verifies that when a cancellation token is cancelled before + the agent starts processing, the agent immediately stops with a + 'cancelled' stop reason without making any model calls. 
+ """ + token = CancellationToken() + agent = Agent( + model=MockedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel immediately before invocation + token.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + # When cancelled, we return an empty assistant message structure + assert result.message == {"role": "assistant", "content": []} + + +@pytest.mark.asyncio +async def test_agent_cancellation_during_execution(): + """Test cancellation during agent execution. + + This test verifies that when a cancellation token is cancelled while + the agent is executing, the agent detects the cancellation at the next + checkpoint and stops gracefully with a 'cancelled' stop reason. + """ + token = CancellationToken() + + # Create a model provider that simulates a delay + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + # Add a small delay before streaming + await asyncio.sleep(0.1) + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent( + model=DelayedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel after a short delay (during execution) + async def cancel_after_delay(): + await asyncio.sleep(0.05) + token.cancel() + + cancel_task = asyncio.create_task(cancel_after_delay()) + result = await agent.invoke_async("Hello") + await cancel_task + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancellation_with_tools(): + """Test cancellation during tool execution. + + This test verifies that when a cancellation token is cancelled while + tools are being executed, the agent stops gracefully and doesn't + execute remaining tools. 
+ """ + from strands import tool + + tool_executed = [] + + @tool + def slow_tool(x: int) -> int: + """A slow tool that takes time to execute.""" + tool_executed.append(x) + time.sleep(0.1) + return x * 2 + + token = CancellationToken() + + # Create a response with tool use + tool_use_response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_1", + "name": "slow_tool", + "input": {"x": 5}, + } + } + ], + } + + agent = Agent( + model=MockedModelProvider([tool_use_response]), + tools=[slow_tool], + cancellation_token=token, + ) + + # Cancel during tool execution + async def cancel_after_delay(): + await asyncio.sleep(0.05) + token.cancel() + + cancel_task = asyncio.create_task(cancel_after_delay()) + result = await agent.invoke_async("Use the tool") + await cancel_task + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_no_cancellation_token(): + """Test that agent works normally without cancellation token. + + This test verifies that when no cancellation token is provided, + the agent executes normally and completes successfully. + """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "end_turn" + assert result.message["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_agent_cancellation_idempotent(): + """Test that multiple cancellations are safe. + + This test verifies that calling cancel() multiple times on the same + token doesn't cause any issues and the agent still stops gracefully. 
+ """ + token = CancellationToken() + agent = Agent( + model=MockedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel multiple times + token.cancel() + token.cancel() + token.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancellation_from_different_thread(): + """Test cancellation from a different thread. + + This test verifies that the cancellation token can be cancelled from + a different thread (simulating a web request or external system) and + the agent will detect it and stop gracefully. + """ + import threading + + token = CancellationToken() + + # Create a model provider with delay + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + await asyncio.sleep(0.1) + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent( + model=DelayedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel from a different thread + def cancel_from_thread(): + time.sleep(0.05) + token.cancel() + + cancel_thread = threading.Thread(target=cancel_from_thread) + cancel_thread.start() + + result = await agent.invoke_async("Hello") + + cancel_thread.join() + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancellation_shared_token(): + """Test that multiple agents can share the same cancellation token. + + This test verifies that when multiple agents share the same cancellation + token, cancelling the token affects all agents using it. 
+ """ + token = CancellationToken() + + agent1 = Agent( + model=MockedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + agent2 = Agent( + model=MockedModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel the shared token + token.cancel() + + result1 = await agent1.invoke_async("Hello from agent 1") + result2 = await agent2.invoke_async("Hello from agent 2") + + assert result1.stop_reason == "cancelled" + assert result2.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancellation_streaming(): + """Test cancellation during streaming response. + + This test verifies that cancellation works correctly when using + the streaming API (stream_async). + """ + token = CancellationToken() + + # Create a model provider that streams slowly + class SlowStreamingModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + # Stream with delays between chunks + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + + # Stream multiple chunks with delays + for i in range(10): + await asyncio.sleep(0.05) + yield {"contentBlockDelta": {"delta": {"text": f"chunk {i} "}}} + + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + agent = Agent( + model=SlowStreamingModelProvider([DEFAULT_RESPONSE]), + cancellation_token=token, + ) + + # Cancel after receiving a few chunks + async def cancel_after_delay(): + await asyncio.sleep(0.15) # Let a few chunks through + token.cancel() + + cancel_task = asyncio.create_task(cancel_after_delay()) + + events = [] + async for event in agent.stream_async("Hello"): + events.append(event) + if event.get("result"): + break + + await cancel_task + + # Find the result event + result_event = next((e for e in events if e.get("result")), None) + assert result_event is not None + assert result_event["result"].stop_reason == "cancelled" diff --git a/tests/strands/types/test_cancellation.py 
b/tests/strands/types/test_cancellation.py new file mode 100644 index 000000000..8d5d56482 --- /dev/null +++ b/tests/strands/types/test_cancellation.py @@ -0,0 +1,143 @@ +"""Tests for CancellationToken.""" + +import threading +import time + +from strands.types import CancellationToken + + +def test_cancellation_token_initial_state(): + """Test that token starts in non-cancelled state. + + Why: We need to ensure the default state is False so agents + don't immediately stop when created with a token. + """ + token = CancellationToken() + assert not token.is_cancelled() + + +def test_cancellation_token_cancel(): + """Test that cancel() sets cancelled state. + + Why: This is the core functionality - when cancel() is called, + the token must transition to cancelled state. + """ + token = CancellationToken() + token.cancel() + assert token.is_cancelled() + + +def test_cancellation_token_idempotent(): + """Test that multiple cancel() calls are safe. + + Why: In distributed systems, multiple cancel requests might arrive. + The token must handle this gracefully without errors or side effects. + """ + token = CancellationToken() + token.cancel() + token.cancel() + token.cancel() + assert token.is_cancelled() + + +def test_cancellation_token_thread_safety(): + """Test that token is thread-safe. + + Why: The token will be accessed from multiple threads: + - Main thread: agent checking is_cancelled() + - Background thread: poller calling cancel() + + This test ensures no race conditions occur. 
+ """ + token = CancellationToken() + results = [] + + def cancel_from_thread(): + time.sleep(0.01) # Small delay to ensure check_from_thread starts first + token.cancel() + + def check_from_thread(): + for _ in range(100): + results.append(token.is_cancelled()) + time.sleep(0.001) + + t1 = threading.Thread(target=cancel_from_thread) + t2 = threading.Thread(target=check_from_thread) + + t2.start() + t1.start() + + t1.join() + t2.join() + + # Should have some False and some True values + # This proves the state transition was visible across threads + assert False in results + assert True in results + + +def test_cancellation_token_multiple_threads_checking(): + """Test that multiple threads can check cancellation simultaneously. + + Why: In complex agents, multiple components might check cancellation + at the same time. This ensures thread-safe reads. + """ + token = CancellationToken() + results = [] + + def check_repeatedly(): + for _ in range(50): + results.append(token.is_cancelled()) + time.sleep(0.001) + + # Start multiple checker threads + threads = [threading.Thread(target=check_repeatedly) for _ in range(3)] + for t in threads: + t.start() + + # Cancel while threads are checking + time.sleep(0.025) + token.cancel() + + for t in threads: + t.join() + + # All threads should have seen both states + assert False in results + assert True in results + # No exceptions should have occurred + + +def test_cancellation_token_shared_reference(): + """Test that token works when shared across objects. + + Why: This simulates the real use case where the same token + is passed to both the agent and the poller. Changes made + through one reference must be visible through the other. 
+ """ + token = CancellationToken() + + # Simulate agent holding reference + class FakeAgent: + def __init__(self, cancellation_token): + self.cancellation_token = cancellation_token + + # Simulate poller holding reference + class FakePoller: + def __init__(self, cancellation_token): + self.cancellation_token = cancellation_token + + agent = FakeAgent(cancellation_token=token) + poller = FakePoller(cancellation_token=token) + + # Verify they're the same object + assert agent.cancellation_token is token + assert poller.cancellation_token is token + assert agent.cancellation_token is poller.cancellation_token + + # Cancel through poller + poller.cancellation_token.cancel() + + # Agent should see the change + assert agent.cancellation_token.is_cancelled() + assert token.is_cancelled() diff --git a/tests_integ/test_cancellation.py b/tests_integ/test_cancellation.py new file mode 100644 index 000000000..46f8b8792 --- /dev/null +++ b/tests_integ/test_cancellation.py @@ -0,0 +1,377 @@ +"""Integration tests for cancellation with real model providers. + +These tests verify that cancellation works correctly with actual model providers +like Bedrock, Anthropic, OpenAI, etc. They require valid credentials and may +incur API costs. + +To run these tests: + hatch run test-integ tests_integ/test_cancellation.py +""" + +import asyncio +import os +import threading +import time + +import pytest + +from strands import Agent, CancellationToken, tool + +# Skip all tests if no model credentials are available +pytestmark = pytest.mark.skipif( + not any( + [ + os.getenv("AWS_REGION"), # Bedrock + os.getenv("ANTHROPIC_API_KEY"), # Anthropic + os.getenv("OPENAI_API_KEY"), # OpenAI + ] + ), + reason="No model provider credentials found", +) + + +@pytest.mark.asyncio +@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available") +async def test_cancellation_with_bedrock(): + """Test cancellation with Amazon Bedrock model. 
+ + This test verifies that cancellation works correctly with a real Bedrock + model. It starts a long-running request and cancels it mid-execution. + """ + from strands.models import BedrockModel + + token = CancellationToken() + agent = Agent( + model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"), + cancellation_token=token, + ) + + # Cancel after 2 seconds + async def cancel_after_delay(): + await asyncio.sleep(2.0) + token.cancel() + + cancel_task = asyncio.create_task(cancel_after_delay()) + + # Request a long response that should take more than 2 seconds + result = await agent.invoke_async( + "Write a detailed 1000-word essay about the history of space exploration, " + "including major milestones, key figures, and technological breakthroughs." + ) + + await cancel_task + + assert result.stop_reason == "cancelled" + # The message might be empty or partially complete + assert result.message is not None + + +@pytest.mark.asyncio +@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="Anthropic API key not available") +async def test_cancellation_with_anthropic(): + """Test cancellation with Anthropic Claude model. + + This test verifies that cancellation works correctly with the Anthropic + API. It starts a long-running request and cancels it mid-execution. + """ + from strands.models import AnthropicModel + + token = CancellationToken() + agent = Agent( + model=AnthropicModel("claude-3-haiku-20240307"), + cancellation_token=token, + ) + + # Cancel after 2 seconds + async def cancel_after_delay(): + await asyncio.sleep(2.0) + token.cancel() + + cancel_task = asyncio.create_task(cancel_after_delay()) + + # Request a long response + result = await agent.invoke_async( + "Write a detailed 1000-word essay about artificial intelligence, " + "covering its history, current applications, and future potential." 
    )

    await cancel_task

    assert result.stop_reason == "cancelled"
    assert result.message is not None


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not available")
async def test_cancellation_with_openai():
    """Test cancellation with OpenAI model.

    This test verifies that cancellation works correctly with the OpenAI
    API. It starts a long-running request and cancels it mid-execution.
    """
    from strands.models import OpenAIModel

    token = CancellationToken()
    agent = Agent(
        model=OpenAIModel("gpt-4o-mini"),
        cancellation_token=token,
    )

    # Cancel after 2 seconds
    async def cancel_after_delay():
        await asyncio.sleep(2.0)
        token.cancel()

    cancel_task = asyncio.create_task(cancel_after_delay())

    # Request a long response
    # NOTE(review): timing-based — if the model finishes in under 2s the
    # stop_reason will not be "cancelled" and this test flakes; confirm the
    # prompt is reliably long enough for the target model.
    result = await agent.invoke_async(
        "Write a detailed 1000-word essay about quantum computing, "
        "explaining the principles, current state, and potential applications."
    )

    # Await the background canceller so the task doesn't outlive the test.
    await cancel_task

    assert result.stop_reason == "cancelled"
    assert result.message is not None


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available")
async def test_cancellation_during_streaming_bedrock():
    """Test cancellation during streaming with Bedrock.

    This test verifies that cancellation works correctly when using the
    streaming API with a real Bedrock model.
    """
    from strands.models import BedrockModel

    token = CancellationToken()
    agent = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
        cancellation_token=token,
    )

    # Cancel after receiving some chunks
    async def cancel_after_delay():
        await asyncio.sleep(1.5)
        token.cancel()

    cancel_task = asyncio.create_task(cancel_after_delay())

    # Collect streamed events until the terminal "result" event appears.
    # NOTE(review): assumes streamed events are dict-like (support .get) —
    # confirm against the stream_async event contract.
    events = []
    async for event in agent.stream_async(
        "Write a detailed story about a space adventure. Make it at least 500 words long."
    ):
        events.append(event)
        if event.get("result"):
            break

    await cancel_task

    # Find the result event
    result_event = next((e for e in events if e.get("result")), None)
    assert result_event is not None
    assert result_event["result"].stop_reason == "cancelled"


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available")
async def test_cancellation_with_tools_bedrock():
    """Test cancellation during tool execution with Bedrock.

    This test verifies that cancellation works correctly when the agent
    is executing tools with a real Bedrock model.
    """
    from strands.models import BedrockModel

    @tool
    def slow_calculation(x: int, y: int) -> int:
        """Perform a slow calculation that takes time.

        Args:
            x: First number
            y: Second number

        Returns:
            The sum of x and y
        """
        time.sleep(2)  # Simulate slow operation
        return x + y

    @tool
    def another_calculation(a: int, b: int) -> int:
        """Another slow calculation.

        Args:
            a: First number
            b: Second number

        Returns:
            The product of a and b
        """
        time.sleep(2)  # Simulate slow operation
        return a * b

    token = CancellationToken()
    agent = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
        tools=[slow_calculation, another_calculation],
        cancellation_token=token,
    )

    # Cancel after 3 seconds (should be during tool execution)
    # NOTE(review): 3s lands mid-tool only if the first model round-trip
    # completes within ~1s — timing assumption, confirm for the integ env.
    async def cancel_after_delay():
        await asyncio.sleep(3.0)
        token.cancel()

    cancel_task = asyncio.create_task(cancel_after_delay())

    result = await agent.invoke_async(
        "Please use the slow_calculation tool to add 5 and 10, then use another_calculation to multiply 3 and 7."
    )

    await cancel_task

    assert result.stop_reason == "cancelled"


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available")
async def test_cancellation_from_different_thread_bedrock():
    """Test cancellation from a different thread with Bedrock.

    This test simulates a real-world scenario where cancellation is triggered
    from a different thread (e.g., a web request handler) while the agent
    is executing in another thread.
    """
    from strands.models import BedrockModel

    token = CancellationToken()
    agent = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
        cancellation_token=token,
    )

    # Cancel from a different thread after 2 seconds — exercises the
    # token's thread-safety (threading.Lock per the patch description).
    def cancel_from_thread():
        time.sleep(2.0)
        token.cancel()

    cancel_thread = threading.Thread(target=cancel_from_thread)
    cancel_thread.start()

    result = await agent.invoke_async(
        "Write a comprehensive guide about machine learning, "
        "covering supervised learning, unsupervised learning, and deep learning. "
        "Make it at least 800 words."
    )

    # Ensure the canceller thread has finished before asserting.
    cancel_thread.join()

    assert result.stop_reason == "cancelled"


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="Anthropic API key not available")
async def test_cancellation_before_model_call_anthropic():
    """Test cancellation before model call with Anthropic.

    This test verifies that when cancellation is requested before the model
    is called, the agent stops immediately without making any API calls.
    """
    from strands.models import AnthropicModel

    token = CancellationToken()
    agent = Agent(
        model=AnthropicModel("claude-3-haiku-20240307"),
        cancellation_token=token,
    )

    # Cancel immediately before invocation
    token.cancel()

    result = await agent.invoke_async("Hello, how are you?")

    assert result.stop_reason == "cancelled"
    # Should not have made any API calls, so message should be empty
    assert result.message == {}


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not available")
async def test_cancellation_idempotent_openai():
    """Test that multiple cancellations are safe with OpenAI.

    This test verifies that calling cancel() multiple times doesn't cause
    any issues with a real model provider.
    """
    from strands.models import OpenAIModel

    token = CancellationToken()
    agent = Agent(
        model=OpenAIModel("gpt-4o-mini"),
        cancellation_token=token,
    )

    # Cancel multiple times
    token.cancel()
    token.cancel()
    token.cancel()

    result = await agent.invoke_async("Tell me a short joke.")

    assert result.stop_reason == "cancelled"


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available")
async def test_cancellation_shared_token_bedrock():
    """Test that multiple agents can share the same cancellation token.

    This test verifies that when multiple agents share a cancellation token,
    cancelling it affects all agents using it.
    """
    from strands.models import BedrockModel

    token = CancellationToken()

    agent1 = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
        cancellation_token=token,
    )

    agent2 = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
        cancellation_token=token,
    )

    # Cancel the shared token — the token is shared by reference, so both
    # agents observe the cancellation at their next checkpoint.
    token.cancel()

    result1 = await agent1.invoke_async("Write a short poem about the ocean.")
    result2 = await agent2.invoke_async("Write a short poem about mountains.")

    assert result1.stop_reason == "cancelled"
    assert result2.stop_reason == "cancelled"


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available")
async def test_agent_without_cancellation_token_bedrock():
    """Test that agent works normally without cancellation token.

    This test verifies that when no cancellation token is provided,
    the agent executes normally with a real model.
    """
    from strands.models import BedrockModel

    agent = Agent(
        model=BedrockModel("anthropic.claude-3-haiku-20240307-v1:0"),
    )

    result = await agent.invoke_async("Say hello in exactly 5 words.")

    # Baseline (no token): normal completion with an assistant message.
    assert result.stop_reason == "end_turn"
    assert result.message["role"] == "assistant"
    assert len(result.message["content"]) > 0