17 changes: 1 addition & 16 deletions src/google/adk/cli/cli_eval.py
@@ -34,11 +34,9 @@
from ..evaluation.eval_case import get_all_tool_calls
from ..evaluation.eval_case import IntermediateDataType
from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_metrics import Interval
from ..evaluation.eval_metrics import MetricInfo
from ..evaluation.eval_metrics import MetricValueInfo
from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.eval_sets_manager import EvalSetsManager
from ..evaluation.metric_defaults import get_default_metric_info
from ..utils.context_utils import Aclosing

logger = logging.getLogger("google_adk." + __name__)
@@ -73,19 +71,6 @@ def _get_agent_module(agent_module_file_path: str):
return _import_from_path(module_name, file_path)


def get_default_metric_info(
metric_name: str, description: str = ""
) -> MetricInfo:
"""Returns a default MetricInfo for a metric."""
return MetricInfo(
metric_name=metric_name,
description=description,
metric_value_info=MetricValueInfo(
interval=Interval(min_value=0.0, max_value=1.0)
),
)


def get_root_agent(agent_module_file_path: str) -> Agent:
"""Returns root agent given the agent module."""
agent_module = _get_agent_module(agent_module_file_path)
21 changes: 21 additions & 0 deletions src/google/adk/evaluation/agent_evaluator.py
@@ -34,6 +34,7 @@
from ..agents.base_agent import BaseAgent
from ..utils.context_utils import Aclosing
from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .custom_metric_evaluator import _CustomMetricEvaluator
from .eval_case import get_all_tool_calls
from .eval_case import IntermediateDataType
from .eval_case import Invocation
@@ -50,6 +51,9 @@
from .evaluator import EvalStatus
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
from .local_eval_sets_manager import convert_eval_set_to_pydantic_schema
from .metric_defaults import get_default_metric_info
from .metric_evaluator_registry import _get_default_metric_evaluator_registry
from .metric_evaluator_registry import MetricEvaluatorRegistry
from .simulation.user_simulator_provider import UserSimulatorProvider

logger = logging.getLogger("google_adk." + __name__)
@@ -154,13 +158,28 @@ async def evaluate_eval_set(
user_simulator_config=eval_config.user_simulator_config
)

metric_evaluator_registry = _get_default_metric_evaluator_registry()
if eval_config.custom_metrics:
for metric_name, config in eval_config.custom_metrics.items():
if config.metric_info:
metric_info = config.metric_info.model_copy()
metric_info.metric_name = metric_name
else:
metric_info = get_default_metric_info(
metric_name=metric_name, description=config.description
)
metric_evaluator_registry.register_evaluator(
metric_info, _CustomMetricEvaluator
)

# Step 1: Perform evals, basically inferencing and evaluation of metrics
eval_results_by_eval_id = await AgentEvaluator._get_eval_results_by_eval_id(
agent_for_eval=agent_for_eval,
eval_set=eval_set,
eval_metrics=eval_metrics,
num_runs=num_runs,
user_simulator_provider=user_simulator_provider,
metric_evaluator_registry=metric_evaluator_registry,
)

# Step 2: Post-process the results!
@@ -536,6 +555,7 @@ async def _get_eval_results_by_eval_id(
eval_metrics: list[EvalMetric],
num_runs: int,
user_simulator_provider: UserSimulatorProvider,
metric_evaluator_registry: Optional[MetricEvaluatorRegistry] = None,
) -> dict[str, list[EvalCaseResult]]:
"""Returns EvalCaseResults grouped by eval case id.

@@ -560,6 +580,7 @@
app_name=app_name, eval_set=eval_set
),
user_simulator_provider=user_simulator_provider,
metric_evaluator_registry=metric_evaluator_registry,
)

inference_requests = [
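
The custom-metrics loop added to evaluate_eval_set above takes one of two paths per entry: reuse a declared MetricInfo (with the map key overriding its name) or fall back to the new default. A minimal sketch of that choice in isolation, assuming a config object with an optional metric_info field and a description string (names mirror the diff; the exact EvalConfig shape is not shown here):

    from google.adk.evaluation.eval_metrics import MetricInfo
    from google.adk.evaluation.metric_defaults import get_default_metric_info


    def resolve_metric_info(metric_name: str, config) -> MetricInfo:
        # The key in custom_metrics wins over any name declared inside metric_info.
        if config.metric_info:
            info = config.metric_info.model_copy()
            info.metric_name = metric_name
            return info
        # Otherwise fall back to the [0.0, 1.0] default from metric_defaults.py.
        return get_default_metric_info(
            metric_name=metric_name, description=config.description
        )

The resulting MetricInfo is then registered against _CustomMetricEvaluator via register_evaluator, so the registry passed into _get_eval_results_by_eval_id already knows how to score the custom metric.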
32 changes: 32 additions & 0 deletions src/google/adk/evaluation/metric_defaults.py
@@ -0,0 +1,32 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .eval_metrics import Interval
from .eval_metrics import MetricInfo
from .eval_metrics import MetricValueInfo


def get_default_metric_info(
metric_name: str, description: str = ""
) -> MetricInfo:
"""Returns a default MetricInfo for a metric."""
return MetricInfo(
metric_name=metric_name,
description=description,
metric_value_info=MetricValueInfo(
interval=Interval(min_value=0.0, max_value=1.0)
),
)
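
A quick usage sketch of the new helper; the [0.0, 1.0] interval comes from the defaults above, and the metric name and description here are examples only:

    from google.adk.evaluation.metric_defaults import get_default_metric_info

    info = get_default_metric_info(
        metric_name="tool_trajectory_length_match",
        description="Compares the lengths of actual and expected tool trajectories.",
    )
    # The default value range is the closed interval [0.0, 1.0].
    assert info.metric_value_info.interval.min_value == 0.0
    assert info.metric_value_info.interval.max_value == 1.0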
4 changes: 3 additions & 1 deletion src/google/adk/evaluation/metric_evaluator_registry.py
@@ -47,7 +47,9 @@
class MetricEvaluatorRegistry:
"""A registry for metric Evaluators."""

_registry: dict[str, tuple[type[Evaluator], MetricInfo]] = {}
def __init__(self):
"""Initializes an empty registry."""
self._registry: dict[str, tuple[type[Evaluator], MetricInfo]] = {}

def get_evaluator(self, eval_metric: EvalMetric) -> Evaluator:
"""Returns an Evaluator for the given metric.
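
The __init__ change matters because the previous class-level dict was shared by every MetricEvaluatorRegistry instance, so registering a custom metric through one registry leaked into all others. A minimal, self-contained illustration of the difference (not the ADK class itself):

    class Shared:
        _registry: dict[str, str] = {}  # class attribute: one dict for all instances


    class PerInstance:
        def __init__(self):
            self._registry: dict[str, str] = {}  # fresh dict per instance


    a, b = Shared(), Shared()
    a._registry["custom_metric"] = "evaluator"
    assert "custom_metric" in b._registry  # leaked into an unrelated instance

    c, d = PerInstance(), PerInstance()
    c._registry["custom_metric"] = "evaluator"
    assert "custom_metric" not in d._registry  # isolated, as the new __init__ ensures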
@@ -0,0 +1,69 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional

from google.adk.evaluation.eval_case import ConversationScenario
from google.adk.evaluation.eval_case import get_all_tool_calls
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.eval_metrics import EvalStatus
from google.adk.evaluation.evaluator import EvaluationResult
from google.adk.evaluation.evaluator import PerInvocationResult


def tool_trajectory_length_match(
eval_metric: EvalMetric,
actual_invocations: list[Invocation],
expected_invocations: Optional[list[Invocation]] = None,
conversation_scenario: Optional[ConversationScenario] = None,
) -> EvaluationResult:
del eval_metric
del conversation_scenario
expected_invocations = expected_invocations or []

per_invocation_results = []
for idx, actual in enumerate(actual_invocations):
expected = (
expected_invocations[idx] if idx < len(expected_invocations) else None
)
actual_tools = get_all_tool_calls(actual.intermediate_data)
expected_tools = (
get_all_tool_calls(expected.intermediate_data) if expected else []
)
match = len(actual_tools) == len(expected_tools)
per_invocation_results.append(
PerInvocationResult(
actual_invocation=actual,
expected_invocation=expected,
score=1.0 if match else 0.0,
eval_status=EvalStatus.PASSED if match else EvalStatus.FAILED,
)
)

overall_score = (
sum(r.score for r in per_invocation_results) / len(per_invocation_results)
if per_invocation_results
else 0.0
)
overall_eval_status = (
EvalStatus.PASSED if overall_score == 1.0 else EvalStatus.FAILED
)
return EvaluationResult(
overall_score=overall_score,
overall_eval_status=overall_eval_status,
per_invocation_results=per_invocation_results,
)
@@ -0,0 +1,65 @@
{
"eval_set_id": "custom_metrics_eval_set",
"name": "custom_metrics_eval_set",
"description": "Custom metric evaluation sample.",
"eval_cases": [
{
"eval_id": "tests/integration/fixture/home_automation_agent/test_files/custom_metrics/simple_custom_metric.test.json",
"conversation": [
{
"invocation_id": "a9e4f840-7f1e-4b69-b9c1-3b85c03a60a4",
"user_content": {
"parts": [
{
"video_metadata": null,
"thought": null,
"code_execution_result": null,
"executable_code": null,
"file_data": null,
"function_call": null,
"function_response": null,
"inline_data": null,
"text": "Turn off device_2 in the Bedroom."
}
],
"role": "user"
},
"final_response": {
"parts": [
{
"video_metadata": null,
"thought": null,
"code_execution_result": null,
"executable_code": null,
"file_data": null,
"function_call": null,
"function_response": null,
"inline_data": null,
"text": "I have set the device_2 status to off."
}
],
"role": "model"
},
"intermediate_data": {
"tool_uses": [
{
"id": null,
"args": {
"location": "Bedroom",
"device_id": "device_2",
"status": "OFF"
},
"name": "set_device_info"
}
],
"intermediate_responses": []
},
"creation_timestamp": 1747337309.2360144
}
],
"session_input": null,
"creation_timestamp": 1747337309.2360282
}
],
"creation_timestamp": 1747337309.2360387
}
@@ -0,0 +1,13 @@
{
"criteria": {
"tool_trajectory_length_match": 1.0
},
"custom_metrics": {
"tool_trajectory_length_match": {
"code_config": {
"name": "tests.integration.fixture.home_automation_agent.test_files.custom_metrics.metrics.tool_trajectory_length_match"
},
"description": "Checks that actual and expected tool trajectories have the same length."
}
}
}
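
The code_config "name" above is a dotted path ending in the metric function defined earlier (…custom_metrics.metrics.tool_trajectory_length_match). The ADK's actual resolution code is not part of this diff; purely as an illustration of the general mechanism, a dotted path like this is typically split into module and attribute and imported:

    import importlib


    def resolve_by_dotted_path(dotted_path: str):
        # e.g. "my_pkg.metrics.tool_trajectory_length_match" ->
        # import my_pkg.metrics, then fetch tool_trajectory_length_match from it.
        module_path, _, attr_name = dotted_path.rpartition(".")
        module = importlib.import_module(module_path)
        return getattr(module, attr_name)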
12 changes: 12 additions & 0 deletions tests/integration/test_with_test_file.py
@@ -25,6 +25,18 @@ async def test_with_single_test_file():
)


@pytest.mark.asyncio
async def test_with_custom_metric():
"""Test eval with a custom metric."""
await AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir=(
"tests/integration/fixture/home_automation_agent/test_files/custom_metrics/simple_custom_metric.test.json"
),
num_runs=1,
)


@pytest.mark.asyncio
async def test_with_folder_of_test_files_long_running():
"""Test the agent's basic ability via a folder of session files."""