diff --git a/AGENTS.md b/AGENTS.md index 6a5765a94..10a66fcd7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,10 +124,12 @@ strands-agents/ │ │ │ ├── hooks/ # Event hooks system │ │ ├── events.py # Hook event definitions -│ │ └── registry.py # Hook registration +│ │ ├── registry.py # Hook registration +│ │ └── _type_inference.py # Event type inference from type hints │ │ │ ├── plugins/ # Plugin system -│ │ ├── plugin.py # Plugin definition +│ │ ├── plugin.py # Plugin base class +│ │ ├── decorator.py # @hook decorator │ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 9dac9ba74..214118d4f 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Any from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....plugins.plugin import Plugin +from ....plugins import Plugin, hook from ....types.content import Message from ....types.streaming import StopReason from ....types.tools import ToolUse @@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non Args: context_providers: List of context providers for context updates """ + super().__init__() self.steering_context = SteeringContext() self._context_callbacks = [] @@ -87,13 +88,8 @@ def init_agent(self, agent: "Agent") -> None: for callback in self._context_callbacks: agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) - # Register tool steering guidance - agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent) - - # Register model steering guidance - agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent) - - async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) @@ -133,7 +129,8 @@ def _handle_tool_steering_action( else: raise ValueError(f"Unknown steering action type for tool call: {action}") - async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: """Provide steering guidance for model response.""" logger.debug("providing model steering guidance") diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py new file mode 100644 index 000000000..aba7d1164 --- /dev/null +++ b/src/strands/hooks/_type_inference.py @@ -0,0 +1,78 @@ +"""Utility for inferring event types from callback type hints.""" + +import inspect +import logging +import types +from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from .registry import HookCallback, TEvent + +logger = logging.getLogger(__name__) + + +def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]": + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + Args: + callback: The callback function to inspect. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + # Import here to avoid circular dependency + from .registry import BaseHookEvent + + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # Skip 'self' parameter for methods + first_param = params[0] + if first_param.name == "self" and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast("type[TEvent]", arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast("type[TEvent]", type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 886ea5644..8b284b0c2 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,24 +9,12 @@ import inspect import logging -import types from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - TypeVar, - Union, - cast, - get_args, - get_origin, - get_type_hints, - runtime_checkable, -) +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException +from ._type_inference import infer_event_types if TYPE_CHECKING: from ..agent import Agent @@ -225,7 +213,7 @@ def multi_handler(event): resolved_event_types = self._validate_event_type_list(event_type) elif event_type is None: # Infer event type(s) from callback type hints - resolved_event_types = self._infer_event_types(callback) + resolved_event_types = infer_event_types(callback) else: # Single event type provided explicitly resolved_event_types = [event_type] @@ -261,67 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ validated.append(et) return validated - def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. - """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError( - "callback has no parameters | cannot infer event type, please provide event_type explicitly" - ) - - first_param = params[0] - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) - def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index aa1491545..c4b7c72c7 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,25 +1,13 @@ """Plugin system for extending agent functionality. This module provides a composable mechanism for building objects that can -extend agent behavior through a standardized initialization pattern. - -Example Usage: - ```python - from strands.plugins import Plugin - - class LoggingPlugin(Plugin): - name = "logging" - - def init_agent(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) - - def on_model_call(self, event: BeforeModelCallEvent) -> None: - print(f"Model called for {event.agent.name}") - ``` +extend agent behavior through automatic hook and tool registration. """ +from .decorator import hook from .plugin import Plugin __all__ = [ "Plugin", + "hook", ] diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py new file mode 100644 index 000000000..fc6f75e5b --- /dev/null +++ b/src/strands/plugins/decorator.py @@ -0,0 +1,69 @@ +"""Hook decorator for Plugin methods. + +Marks methods as hook callbacks for automatic registration when the plugin +is attached to an agent. Infers event types from type hints and supports +union types for multiple events. + +Example: + ```python + class MyPlugin(Plugin): + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(event) + ``` +""" + +from collections.abc import Callable +from typing import Generic, cast, overload + +from ..hooks._type_inference import infer_event_types +from ..hooks.registry import HookCallback, TEvent + + +class _WrappedHookCallable(HookCallback, Generic[TEvent]): + """Wrapped version of HookCallback that includes a `_hook_event_types` attribute.""" + + _hook_event_types: list[type[TEvent]] + + +# Handle @hook +@overload +def hook(__func: HookCallback) -> _WrappedHookCallable: ... + + +# Handle @hook() +@overload +def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ... + + +def hook( + func: HookCallback | None = None, +) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]: + """Mark a method as a hook callback for automatic registration. + + Infers event type from the callback's type hint. Supports union types + for multiple events. Can be used as @hook or @hook(). + + Args: + func: The function to decorate. + + Returns: + The decorated function with hook metadata. + + Raises: + ValueError: If event type cannot be inferred from type hints. + """ + + def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]: + # Infer event types from type hints + event_types: list[type[TEvent]] = infer_event_types(f) + + # Store hook metadata on the function + f_wrapped = cast(_WrappedHookCallable, f) + f_wrapped._hook_event_types = event_types + + return f_wrapped + + if func is None: + return decorator + return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index e9f35f112..b670de297 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -1,34 +1,70 @@ """Plugin base class for extending agent functionality. This module defines the Plugin base class, which provides a composable way to -add behavior changes to agents through a standardized initialization pattern. +add behavior changes to agents through automatic hook and tool registration. """ +import logging from abc import ABC, abstractmethod from collections.abc import Awaitable from typing import TYPE_CHECKING +from ..hooks.registry import HookCallback +from ..tools.decorator import DecoratedFunctionTool + if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) + class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They can register hooks, modify agent attributes, or perform other - setup tasks on an agent instance. + They support automatic discovery and registration of methods decorated + with @hook and @tool decorators. Attributes: - name: A stable string identifier for the plugin + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the agent, auto-discovered from @hook decorated methods during __init__ + tools: Tools attached to the agent, auto-discovered from @tool decorated methods during __init__ - Example: + Example using decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + from strands import tool + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @tool + def my_tool(self, param: str) -> str: + '''A tool that does something.''' + return f"Result: {param}" + ``` + + Note: Decorated methods are registered in declaration order, with parent + class methods registered before child class methods. If a child overrides + a parent's decorated method, only the child's version is registered. + + Example with custom initialization: ```python class MyPlugin(Plugin): name = "my-plugin" def init_agent(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + # Custom initialization logic - no super() needed + # Decorated hooks/tools are auto-registered by the plugin registry + agent.add_hook(self.custom_hook) + + def custom_hook(self, event: BeforeModelCallEvent): + print(event) ``` """ @@ -38,11 +74,59 @@ def name(self) -> str: """A stable string identifier for the plugin.""" ... - @abstractmethod + def __init__(self) -> None: + """Initialize the plugin and discover decorated methods. + + Scans the class for methods decorated with @hook and @tool and stores + references for later registration when the plugin is attached to an agent. + """ + self._hooks: list[HookCallback] = [] + self._tools: list[DecoratedFunctionTool] = [] + self._discover_decorated_methods() + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + @property + def tools(self) -> list[DecoratedFunctionTool]: + """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" + return self._tools + + def _discover_decorated_methods(self) -> None: + """Scan class for @hook and @tool decorated methods in declaration order.""" + seen: set[str] = set() + # Walk MRO so parent class hooks come first, child overrides win + for cls in reversed(type(self).__mro__): + for name in cls.__dict__: + if name in seen: + continue + seen.add(name) + + # Get the bound method from self + try: + bound = getattr(self, name) + except Exception: + continue + + # Check for @hook decorated methods + if hasattr(bound, "_hook_event_types") and callable(bound): + self._hooks.append(bound) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) + + # Check for @tool decorated methods (DecoratedFunctionTool instances) + if isinstance(bound, DecoratedFunctionTool): + self._tools.append(bound) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) + def init_agent(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the agent instance. + Override this method to add custom initialization logic. Decorated + hooks and tools are automatically registered by the plugin registry. + Args: agent: The agent instance to initialize. """ - ... + return None diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index 3b8a0a45f..a75858680 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -24,6 +24,11 @@ class _PluginRegistry: The _PluginRegistry tracks plugins that have been initialized with an agent, providing methods to add plugins and invoke their initialization. + The registry handles: + 1. Calling the plugin's init_agent() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the agent + 3. Auto-registering discovered @tool decorated methods with the agent + Example: ```python registry = _PluginRegistry(agent) @@ -31,7 +36,12 @@ class _PluginRegistry: class MyPlugin(Plugin): name = "my-plugin" + @hook + def on_event(self, event: BeforeModelCallEvent): + pass # Auto-registered by registry + def init_agent(self, agent: Agent) -> None: + # Custom logic only - no super() needed pass plugin = MyPlugin() @@ -51,7 +61,12 @@ def __init__(self, agent: "Agent") -> None: def add_and_init(self, plugin: Plugin) -> None: """Add and initialize a plugin with the agent. - This method registers the plugin and calls its init_agent method. + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the agent's hook registry + 4. Auto-registers all discovered @tool methods with the agent's tool registry + Handles both sync and async init_agent implementations automatically. Args: @@ -66,8 +81,51 @@ def add_and_init(self, plugin: Plugin) -> None: logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) self._plugins[plugin.name] = plugin + # Call user's init_agent for custom initialization if inspect.iscoroutinefunction(plugin.init_agent): async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_agent) run_async(lambda: async_plugin_init(self._agent)) else: plugin.init_agent(self._agent) + + # Auto-register discovered hooks with the agent's hook registry + self._register_hooks(plugin) + + # Auto-register discovered tools with the agent's tool registry + self._register_tools(plugin) + + def _register_hooks(self, plugin: Plugin) -> None: + """Register all discovered hooks from the plugin with the agent. + + Warns if a hook callback is already registered for an event type, + which can happen when init_agent() manually registers a hook that + is also decorated with @hook. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._agent.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + def _register_tools(self, plugin: Plugin) -> None: + """Register all discovered tools from the plugin with the agent. + + Args: + plugin: The plugin whose tools should be registered. + """ + if plugin.tools: + self._agent.tool_registry.process_tools(list(plugin.tools)) + for tool in plugin.tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + plugin.name, + tool.tool_name, + ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 55de68ff1..0491c8686 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -14,7 +14,7 @@ from pydantic import BaseModel import strands -from strands import Agent, ToolContext +from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager @@ -2625,6 +2625,8 @@ def test_agent_plugins_sync_initialization(): """Test that plugins with sync init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "test-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] plugin_mock.init_agent = unittest.mock.Mock() agent = Agent( @@ -2639,6 +2641,8 @@ def test_agent_plugins_async_initialization(): """Test that plugins with async init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "async-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] plugin_mock.init_agent = unittest.mock.AsyncMock() agent = Agent( @@ -2655,10 +2659,14 @@ def test_agent_plugins_multiple_in_order(): plugin1 = unittest.mock.Mock() plugin1.name = "plugin1" + plugin1.hooks = [] + plugin1.tools = [] plugin1.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) plugin2 = unittest.mock.Mock() plugin2.name = "plugin2" + plugin2.hooks = [] + plugin2.tools = [] plugin2.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) Agent( @@ -2673,7 +2681,7 @@ def test_agent_plugins_can_register_hooks(): """Test that plugins can register hooks during initialization.""" hook_called = [] - class TestPlugin: + class TestPlugin(Plugin): name = "hook-plugin" def init_agent(self, agent): diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 90064ea98..1f247120a 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,5 +1,6 @@ """Unit tests for steering handler base class.""" +import inspect from unittest.mock import AsyncMock, Mock import pytest @@ -8,6 +9,7 @@ from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from strands.hooks.registry import HookRegistry from strands.plugins import Plugin @@ -38,15 +40,24 @@ def test_steering_handler_is_plugin(): def test_init_agent(): - """Test init_agent registers hooks on agent.""" + """Test init_agent with plugin registry registers hooks on agent.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Verify hooks were registered (tool and model steering hooks) + # Verify hooks were registered (tool and model steering hooks via @hook decorator) assert agent.add_hook.call_count >= 2 - agent.add_hook.assert_any_call(handler._provide_tool_steering_guidance, BeforeToolCallEvent) + # Check that the decorated hook methods were registered + assert BeforeToolCallEvent in agent.hooks._registered_callbacks + assert AfterModelCallEvent in agent.hooks._registered_callbacks def test_steering_context_initialization(): @@ -86,7 +97,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -105,7 +116,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." @@ -126,7 +137,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -145,7 +156,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -165,7 +176,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) def test_init_agent_override(): @@ -218,62 +229,86 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): - """Test that handler registers hooks from context callbacks.""" + """Test that handler registers hooks from context callbacks via registry.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register hooks for context callback and steering guidance + # Should register hooks for context callback (via init_agent) and steering guidance (via @hook) + # init_agent registers context callbacks manually, @hook decorated methods are auto-registered assert agent.add_hook.call_count >= 2 - # Check that BeforeToolCallEvent was registered - call_args = [call[0] for call in agent.add_hook.call_args_list] - event_types = [args[1] for args in call_args] + # Check that BeforeToolCallEvent was registered (both context callback and steering guidance) + assert BeforeToolCallEvent in agent.hooks._registered_callbacks - assert BeforeToolCallEvent in event_types - -def test_context_callbacks_receive_steering_context(): +@pytest.mark.asyncio +async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Get the registered callback for BeforeToolCallEvent - before_callback = None - for call in agent.add_hook.call_args_list: - if call[0][1] == BeforeToolCallEvent: - before_callback = call[0][0] - break + # Get the registered callbacks for BeforeToolCallEvent + callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) + assert len(callbacks) > 0 - assert before_callback is not None - - # Create a mock event and call the callback + # The context callback is wrapped in a lambda, so we just call all callbacks + # and check if the steering context was updated event = Mock(spec=BeforeToolCallEvent) event.tool_use = {"name": "test_tool", "input": {}} - # The callback should execute without error and update the steering context - before_callback(event) + # Call all callbacks, handling both sync and async + for cb in callbacks: + try: + result = await cb(event) + if inspect.iscoroutine(result): + await result + except Exception: + pass # Some callbacks might be async or have other requirements - # Verify the steering context was updated + # Verify the steering context was updated by at least one callback assert handler.steering_context.data.get("test_key") == "test_value" def test_multiple_context_callbacks_registered(): - """Test that multiple context callbacks are registered.""" + """Test that multiple context callbacks are registered via registry.""" + from strands.plugins.registry import _PluginRegistry + callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register one callback for each context provider plus tool and model steering guidance + # Should register: + # - 2 callbacks for context providers (via init_agent manual registration) + # - 2 for steering guidance (via @hook decorator auto-registration) expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert agent.add_hook.call_count >= expected_calls @@ -310,7 +345,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should not set retry for Proceed assert event.retry is False @@ -334,7 +369,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should set retry flag assert event.retry is True @@ -362,7 +397,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event = Mock(spec=AfterModelCallEvent) event.stop_response = None - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # steer_after_model should not have been called assert handler.steer_called is False @@ -386,7 +421,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -407,7 +442,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -429,7 +464,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.retry = False # Should not raise, just return early - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # retry should not be set since exception occurred assert event.retry is False @@ -449,7 +484,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) # Should not raise, just return early - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # cancel_tool should not be set since exception occurred assert not event.cancel_tool @@ -486,11 +521,20 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_agent_registers_model_steering(): - """Test that init_agent registers model steering callback.""" + """Test that model steering hook is registered via plugin registry.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() - - handler.init_agent(agent) - - # Verify model steering hook was registered - agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Verify model steering hook was registered via @hook decorator + assert AfterModelCallEvent in agent.hooks._registered_callbacks + callbacks = agent.hooks._registered_callbacks[AfterModelCallEvent] + assert len(callbacks) == 1 diff --git a/tests/strands/plugins/test_hook_decorator.py b/tests/strands/plugins/test_hook_decorator.py new file mode 100644 index 000000000..520040c9d --- /dev/null +++ b/tests/strands/plugins/test_hook_decorator.py @@ -0,0 +1,232 @@ +"""Tests for the @hook decorator.""" + +import unittest.mock + +import pytest + +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, +) +from strands.plugins.decorator import hook + + +class TestHookDecoratorBasic: + """Tests for basic @hook decorator functionality.""" + + def test_hook_decorator_marks_method(self): + """Test that @hook marks a method with hook metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_with_parentheses(self): + """Test that @hook() syntax also works.""" + + @hook() + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_preserves_function_metadata(self): + """Test that @hook preserves the original function's metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + """Docstring for the hook.""" + pass + + assert on_before_model_call.__name__ == "on_before_model_call" + assert on_before_model_call.__doc__ == "Docstring for the hook." + + def test_hook_decorator_function_still_callable(self): + """Test that decorated function can still be called normally.""" + call_count = 0 + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + on_before_model_call(mock_event) + assert call_count == 1 + + +class TestHookDecoratorEventTypeInference: + """Tests for event type inference from type hints.""" + + def test_hook_infers_event_type_from_type_hint(self): + """Test that @hook infers event type from the first parameter's type hint.""" + + @hook + def handler(event: BeforeInvocationEvent): + pass + + assert BeforeInvocationEvent in handler._hook_event_types + + def test_hook_infers_different_event_types(self): + """Test that different event types are correctly inferred.""" + + @hook + def handler1(event: BeforeModelCallEvent): + pass + + @hook + def handler2(event: AfterModelCallEvent): + pass + + @hook + def handler3(event: AfterInvocationEvent): + pass + + assert BeforeModelCallEvent in handler1._hook_event_types + assert AfterModelCallEvent in handler2._hook_event_types + assert AfterInvocationEvent in handler3._hook_event_types + + +class TestHookDecoratorUnionTypes: + """Tests for union type support in @hook decorator.""" + + def test_hook_supports_union_types_with_pipe(self): + """Test that @hook supports union types using | syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_union_types_with_typing_union(self): + """Test that @hook supports Union[] syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_multiple_union_types(self): + """Test that @hook supports unions with more than two types.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + assert BeforeInvocationEvent in handler._hook_event_types + + +class TestHookDecoratorErrorHandling: + """Tests for error handling in @hook decorator.""" + + def test_hook_raises_error_without_type_hint(self): + """Test that @hook raises error when no type hint is provided.""" + with pytest.raises(ValueError, match="cannot infer event type"): + + @hook + def handler(event): + pass + + def test_hook_raises_error_with_non_hook_event_type(self): + """Test that @hook raises error when type hint is not a HookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def handler(event: str): + pass + + def test_hook_raises_error_with_none_in_union(self): + """Test that @hook raises error when union contains None.""" + with pytest.raises(ValueError, match="None is not a valid event type"): + + @hook + def handler(event: BeforeModelCallEvent | None): + pass + + +class TestHookDecoratorWithMethods: + """Tests for @hook decorator on class methods.""" + + def test_hook_works_on_instance_method(self): + """Test that @hook works correctly on instance methods.""" + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + pass + + instance = MyClass() + assert hasattr(instance.handler, "_hook_event_types") + assert BeforeModelCallEvent in instance.handler._hook_event_types + + def test_hook_instance_method_is_callable(self): + """Test that decorated instance method can be called.""" + call_count = 0 + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert call_count == 1 + + def test_hook_method_accesses_self(self): + """Test that decorated method can access self.""" + + class MyClass: + def __init__(self): + self.events_received = [] + + @hook + def handler(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert len(instance.events_received) == 1 + assert instance.events_received[0] is mock_event + + +class TestHookDecoratorAsync: + """Tests for async functions with @hook decorator.""" + + def test_hook_works_on_async_function(self): + """Test that @hook works on async functions.""" + + @hook + async def handler(event: BeforeModelCallEvent): + pass + + assert hasattr(handler, "_hook_event_types") + assert BeforeModelCallEvent in handler._hook_event_types + + @pytest.mark.asyncio + async def test_hook_async_function_is_callable(self): + """Test that decorated async function can be awaited.""" + call_count = 0 + + @hook + async def handler(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + await handler(mock_event) + assert call_count == 1 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py new file mode 100644 index 000000000..dab3e7210 --- /dev/null +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -0,0 +1,553 @@ +"""Tests for the Plugin base class with auto-discovery.""" + +import unittest.mock + +import pytest + +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.registry import _PluginRegistry +from strands.tools.decorator import tool + + +def _configure_mock_agent_with_hooks(): + """Helper to create a mock agent with working add_hook.""" + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.add_hook.side_effect = lambda callback, event_type=None: mock_agent.hooks.add_callback( + event_type, callback + ) + return mock_agent + + +class TestPluginBaseClass: + """Tests for Plugin base class basics.""" + + def test_plugin_is_class_not_protocol(self): + """Test that Plugin is now a class, not a Protocol.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + def test_plugin_requires_name_attribute(self): + """Test that Plugin subclass must have name attribute.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert plugin.name == "my-plugin" + + def test_plugin_name_as_property(self): + """Test that Plugin name can be a property.""" + + class MyPlugin(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = MyPlugin() + assert plugin.name == "property-plugin" + + +class TestPluginAutoDiscovery: + """Tests for automatic discovery of decorated methods.""" + + def test_plugin_discovers_hook_decorated_methods(self): + """Test that Plugin.__init__ discovers @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_model" + + def test_plugin_discovers_multiple_hooks(self): + """Test that Plugin discovers multiple @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + hook_names = {h.__name__ for h in plugin.hooks} + assert "hook1" in hook_names + assert "hook2" in hook_names + + def test_hooks_preserve_definition_order(self): + """Test that hooks are discovered in definition order, not alphabetical.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def z_last_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def a_first_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def m_middle_alphabetically(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 3 + # Should be in definition order, not alphabetical + assert plugin.hooks[0].__name__ == "z_last_alphabetically" + assert plugin.hooks[1].__name__ == "a_first_alphabetically" + assert plugin.hooks[2].__name__ == "m_middle_alphabetically" + + def test_plugin_discovers_tool_decorated_methods(self): + """Test that Plugin.__init__ discovers @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "my_tool" + + def test_plugin_discovers_both_hooks_and_tools(self): + """Test that Plugin discovers both @hook and @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert len(plugin.tools) == 1 + + def test_plugin_ignores_non_decorated_methods(self): + """Test that Plugin doesn't discover non-decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + def test_hooks_property_returns_list(self): + """Test that hooks property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert isinstance(plugin.hooks, list) + + def test_tools_property_returns_list(self): + """Test that tools property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert isinstance(plugin.tools, list) + + def test_hooks_can_be_filtered(self): + """Test that hooks list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + + # Filter out hook1 + plugin.hooks[:] = [h for h in plugin.hooks if h.__name__ != "hook1"] + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "hook2" + + def test_tools_can_be_filtered(self): + """Test that tools list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def tool1(self, param: str) -> str: + """Tool 1.""" + return param + + @tool + def tool2(self, param: str) -> str: + """Tool 2.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 2 + + # Filter out tool1 + plugin.tools[:] = [t for t in plugin.tools if t.tool_name != "tool1"] + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "tool2" + + +class TestPluginRegistryAutoRegistration: + """Tests for auto-registration via _PluginRegistry.""" + + def test_registry_registers_hooks_with_agent(self): + """Test that _PluginRegistry registers discovered hooks with agent.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_registry_registers_tools_with_agent(self): + """Test that _PluginRegistry adds discovered tools to agent's tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify tool was added to agent + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_registers_both_hooks_and_tools(self): + """Test that _PluginRegistry registers both hooks and tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify both registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_calls_init_agent_before_registration(self): + """Test that _PluginRegistry calls init_agent for custom logic.""" + init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal init_called + init_called = True + # Custom logic - no super() needed + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert init_called + # Verify auto-registration still happened + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginHookWithUnionTypes: + """Tests for Plugin hooks with union types.""" + + def test_registry_registers_hook_for_union_types(self): + """Test that hooks with union types are registered for all event types.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered for both event types + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginMultipleAgents: + """Tests for plugin reuse with multiple agents.""" + + def test_plugin_can_be_attached_to_multiple_agents(self): + """Test that the same plugin instance can be used with multiple agents.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + + mock_agent1 = _configure_mock_agent_with_hooks() + mock_agent2 = _configure_mock_agent_with_hooks() + + # Note: In practice, different registries would be used for each agent + # Here we simulate attaching to multiple agents directly + registry1 = _PluginRegistry(mock_agent1) + registry1.add_and_init(plugin) + + # Create new plugin instance for second agent (same class) + plugin2 = MyPlugin() + registry2 = _PluginRegistry(mock_agent2) + registry2.add_and_init(plugin2) + + # Verify both agents have the hook registered + assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent2.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginSubclassOverride: + """Tests for subclass overriding init_agent.""" + + def test_subclass_can_override_init_agent_without_super(self): + """Test that subclass can override init_agent without calling super().""" + custom_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal custom_init_called + custom_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert custom_init_called + # Verify auto-registration still happened via registry + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_subclass_can_add_manual_hooks(self): + """Test that subclass can manually add hooks in addition to decorated ones.""" + manual_hook_added = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeModelCallEvent): + pass + + def manual_hook(self, event: BeforeInvocationEvent): + pass + + def init_agent(self, agent): + nonlocal manual_hook_added + # Add manual hook - no super() needed + agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) + manual_hook_added = True + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert manual_hook_added + # Verify both hooks registered (1 manual + 1 auto) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginAsyncInitPlugin: + """Tests for async init_agent support.""" + + @pytest.mark.asyncio + async def test_async_init_agent_supported(self): + """Test that async init_agent is supported.""" + async_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + async def init_agent(self, agent): + nonlocal async_init_called + async_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify async init was called (run_async handles it) + assert async_init_called + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginBoundMethods: + """Tests for bound method registration.""" + + def test_hooks_are_bound_to_instance(self): + """Test that registered hooks are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Call the registered hook and verify it accesses the correct instance + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + callbacks = list(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) + callbacks[0](mock_event) + + assert len(plugin.events_received) == 1 + assert plugin.events_received[0] is mock_event + + def test_tools_are_bound_to_instance(self): + """Test that registered tools are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.tool_called = False + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + self.tool_called = True + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Get the tool that was registered and call it + call_args = mock_agent.tool_registry.process_tools.call_args + registered_tools = call_args[0][0] + assert len(registered_tools) == 1 + + # Call the tool - it should be bound to the instance + result = registered_tools[0]("test") + assert plugin.tool_called + assert result == "test" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index c16cfcf7a..04b39718b 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -4,38 +4,39 @@ import pytest +from strands.hooks import HookRegistry from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry -# Plugin Tests +# Plugin Base Class Tests -def test_plugin_class_requires_inheritance(): - """Test that Plugin class requires inheritance.""" +def test_plugin_base_class_isinstance_check(): + """Test that Plugin subclass passes isinstance check.""" class MyPlugin(Plugin): name = "my-plugin" - def init_agent(self, agent): - pass - plugin = MyPlugin() assert isinstance(plugin, Plugin) -def test_plugin_class_sync_implementation(): - """Test Plugin class works with synchronous init_agent.""" +def test_plugin_base_class_sync_implementation(): + """Test Plugin base class works with synchronous init_agent.""" class SyncPlugin(Plugin): name = "sync-plugin" def init_agent(self, agent): + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,19 +46,22 @@ def init_agent(self, agent): @pytest.mark.asyncio -async def test_plugin_class_async_implementation(): - """Test Plugin class works with asynchronous init_agent.""" +async def test_plugin_base_class_async_implementation(): + """Test Plugin base class works with asynchronous init_agent.""" class AsyncPlugin(Plugin): name = "async-plugin" async def init_agent(self, agent): + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -78,42 +82,37 @@ def init_agent(self, agent): PluginWithoutName() -def test_plugin_class_requires_init_agent_method(): - """Test that Plugin class requires an init_agent method.""" +def test_plugin_base_class_requires_init_agent_method(): + """Test that Plugin base class provides default init_agent.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): + class PluginWithoutOverride(Plugin): + name = "no-override-plugin" - class PluginWithoutInitPlugin(Plugin): - name = "incomplete-plugin" + plugin = PluginWithoutOverride() + # Plugin base class provides default init_agent + assert hasattr(plugin, "init_agent") + assert callable(plugin.init_agent) - PluginWithoutInitPlugin() - -def test_plugin_class_with_class_attribute_name(): - """Test Plugin class works when name is a class attribute.""" +def test_plugin_base_class_with_class_attribute_name(): + """Test Plugin base class works when name is a class attribute.""" class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" - def init_agent(self, agent): - pass - plugin = PluginWithClassAttribute() assert isinstance(plugin, Plugin) assert plugin.name == "class-attr-plugin" -def test_plugin_class_with_property_name(): - """Test Plugin class works when name is a property.""" +def test_plugin_base_class_with_property_name(): + """Test Plugin base class works when name is a property.""" class PluginWithProperty(Plugin): @property - def name(self): + def name(self) -> str: return "property-plugin" - def init_agent(self, agent): - pass - plugin = PluginWithProperty() assert isinstance(plugin, Plugin) assert plugin.name == "property-plugin" @@ -125,7 +124,11 @@ def init_agent(self, agent): @pytest.fixture def mock_agent(): """Create a mock agent for testing.""" - return unittest.mock.Mock() + agent = unittest.mock.Mock() + agent.hooks = HookRegistry() + agent.tool_registry = unittest.mock.MagicMock() + agent.add_hook = unittest.mock.Mock() + return agent @pytest.fixture @@ -141,9 +144,11 @@ class TestPlugin(Plugin): name = "test-plugin" def __init__(self): + super().__init__() self.initialized = False def init_agent(self, agent): + # No super() needed - registry handles auto-registration self.initialized = True agent.plugin_initialized = True @@ -160,9 +165,6 @@ def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): class TestPlugin(Plugin): name = "test-plugin" - def init_agent(self, agent): - pass - plugin1 = TestPlugin() plugin2 = TestPlugin() @@ -179,9 +181,11 @@ class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): + super().__init__() self.initialized = False async def init_agent(self, agent): + # No super() needed - registry handles auto-registration self.initialized = True agent.async_plugin_initialized = True