diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index 33c1693208..e82f444993 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -24,7 +24,7 @@ import click from google.genai import types as genai_types -from ..agents.llm_agent import Agent +from ..agents.base_agent import BaseAgent from ..evaluation.base_eval_service import BaseEvalService from ..evaluation.base_eval_service import EvaluateConfig from ..evaluation.base_eval_service import EvaluateRequest @@ -86,11 +86,18 @@ def get_default_metric_info( ) -def get_root_agent(agent_module_file_path: str) -> Agent: +async def get_root_agent(agent_module_file_path: str) -> BaseAgent: """Returns root agent given the agent module.""" agent_module = _get_agent_module(agent_module_file_path) - root_agent = agent_module.agent.root_agent - return root_agent + agent_module_with_agent = getattr(agent_module, "agent", agent_module) + if hasattr(agent_module_with_agent, "root_agent"): + return agent_module_with_agent.root_agent + elif hasattr(agent_module_with_agent, "get_agent_async"): + root_agent, _ = await agent_module_with_agent.get_agent_async() + return root_agent + raise ValueError( + "Agent module should have either `root_agent` or `get_agent_async`." + ) def try_get_reset_func(agent_module_file_path: str) -> Any: diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 5b5d3e5c82..fec7dd5a1a 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -812,7 +812,7 @@ def cli_eval( print(f"Using evaluation criteria: {eval_config}") eval_metrics = get_eval_metrics_from_config(eval_config) - root_agent = get_root_agent(agent_module_file_path) + root_agent = asyncio.run(get_root_agent(agent_module_file_path)) app_name = os.path.basename(agent_module_file_path) agents_dir = os.path.dirname(agent_module_file_path) eval_sets_manager = None diff --git a/tests/unittests/cli/utils/test_cli_eval.py b/tests/unittests/cli/utils/test_cli_eval.py index c6d21fa707..c5f9d1918a 100644 --- a/tests/unittests/cli/utils/test_cli_eval.py +++ b/tests/unittests/cli/utils/test_cli_eval.py @@ -19,6 +19,8 @@ from types import SimpleNamespace from unittest import mock +import pytest + def test_get_eval_sets_manager_local(monkeypatch): mock_local_manager = mock.MagicMock() @@ -49,3 +51,46 @@ def test_get_eval_sets_manager_gcs(monkeypatch): ) assert manager == mock_gcs_manager mock_create_gcs.assert_called_once_with("gs://bucket") + + +@pytest.mark.asyncio +async def test_get_root_agent_supports_root_agent(monkeypatch): + root_agent = mock.MagicMock() + agent_module = SimpleNamespace(agent=SimpleNamespace(root_agent=root_agent)) + monkeypatch.setattr( + "google.adk.cli.cli_eval._get_agent_module", + lambda _agent_module_file_path: agent_module, + ) + from google.adk.cli.cli_eval import get_root_agent + + assert await get_root_agent("some/dir") == root_agent + + +@pytest.mark.asyncio +async def test_get_root_agent_supports_get_agent_async(monkeypatch): + root_agent = mock.MagicMock() + get_agent_async = mock.AsyncMock(return_value=(root_agent, object())) + agent_module = SimpleNamespace( + agent=SimpleNamespace(get_agent_async=get_agent_async) + ) + monkeypatch.setattr( + "google.adk.cli.cli_eval._get_agent_module", + lambda _agent_module_file_path: agent_module, + ) + from google.adk.cli.cli_eval import get_root_agent + + assert await get_root_agent("some/dir") == root_agent + get_agent_async.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_root_agent_raises_without_supported_entrypoint(monkeypatch): + agent_module = SimpleNamespace(agent=SimpleNamespace()) + monkeypatch.setattr( + "google.adk.cli.cli_eval._get_agent_module", + lambda _agent_module_file_path: agent_module, + ) + from google.adk.cli.cli_eval import get_root_agent + + with pytest.raises(ValueError, match="root_agent|get_agent_async"): + await get_root_agent("some/dir") diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 5977f90dd7..0ebd181dd2 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -59,7 +59,9 @@ def mock_load_eval_set_from_file(): @pytest.fixture def mock_get_root_agent(): - with mock.patch("google.adk.cli.cli_eval.get_root_agent") as mock_func: + with mock.patch( + "google.adk.cli.cli_eval.get_root_agent", new_callable=mock.AsyncMock + ) as mock_func: mock_func.return_value = root_agent yield mock_func