diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index a40eb0f45..e7cdbe131 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -48,8 +48,10 @@ jobs: with: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run integration tests env: AWS_REGION: us-east-1 diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index bf2c9f21d..7c96a9789 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -34,9 +34,11 @@ jobs: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | python -m pip install --upgrade pip - pip install hatch twine + pip install hatch twine 'virtualenv<21' - name: Validate version run: | diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 89cc459de..5f5aa6fcd 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -83,8 +83,10 @@ jobs: # Windows typically has audio libraries available by default echo "Windows audio dependencies handled by PyAudio wheels" - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run Unit tests id: tests run: hatch test tests --cover @@ -118,8 +120,10 @@ jobs: sudo apt-get install -y portaudio19-dev libasound2-dev - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run lint id: lint diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 70552d6ba..0f91349d2 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -43,6 +43,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import asyncio import functools import inspect +import json import logging from collections.abc import Callable from typing import ( @@ -61,6 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import docstring_parser from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo +from pydantic_core import PydanticSerializationError from typing_extensions import override from ..interrupt import InterruptException @@ -644,12 +646,25 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format - # Always include at least one content item for consistency + # Serialize to JSON for consistent, parseable output (except strings) + if isinstance(result, str): + text = result + elif isinstance(result, BaseModel): + try: + text = result.model_dump_json() + except PydanticSerializationError: + text = str(result) + else: + try: + text = json.dumps(result) + except (TypeError, ValueError): + text = str(result) + return ToolResultEvent( { "toolUseId": tool_use_d, "status": "success", - "content": [{"text": str(result)}], + "content": [{"text": text}], }, exception=exception, ) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index f3d6eda02..cc1158983 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -136,7 +136,7 @@ def identity(a: int, agent: dict = None): tru_events = await alist(stream) exp_events = [ - ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": '[2, {"state": 1}]'}]}) ] assert tru_events == exp_events @@ -595,12 +595,12 @@ def none_return_tool(param: str) -> None: assert result["tool_result"]["status"] == "success" assert result["tool_result"]["content"][0]["text"] == "Result: test" - # Test None return - should still create valid ToolResult with "None" text + # Test None return - should still create valid ToolResult with "null" stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -861,7 +861,7 @@ def int_return_tool(param: str) -> int: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" # Define tool with Union return type @strands.tool @@ -884,10 +884,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert ( - "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] - or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] - ) + assert result["tool_result"]["content"][0]["text"] == '{"key": "value"}' tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) @@ -901,7 +898,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -992,6 +989,132 @@ def custom_result_tool(param: str) -> dict[str, Any]: assert result["tool_result"]["content"][1]["type"] == "markdown" +@pytest.mark.asyncio +async def test_tool_result_json_serialization_dict(alist): + """Test that dict results are serialized as JSON.""" + + @strands.tool + def dict_tool() -> dict: + """Returns a dict.""" + return {"key": "value", "number": 42} + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = dict_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"key": "value", "number": 42}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_list(alist): + """Test that list results are serialized as JSON.""" + + @strands.tool + def list_tool() -> list: + """Returns a list.""" + return [1, "two", {"three": 3}] + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = list_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '[1, "two", {"three": 3}]' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic(alist): + """Test that Pydantic model results are serialized as JSON.""" + from pydantic import BaseModel + + class MyModel(BaseModel): + name: str + count: int + + @strands.tool + def pydantic_tool() -> MyModel: + """Returns a Pydantic model.""" + return MyModel(name="test", count=5) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"name":"test","count":5}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic_non_serializable(alist): + """Test that Pydantic models with non-serializable fields fall back to str().""" + from pydantic import BaseModel + + class NonSerializable: + def __repr__(self): + return "NonSerializable()" + + class MyModel(BaseModel): + model_config = {"arbitrary_types_allowed": True} + data: NonSerializable + + @strands.tool + def pydantic_tool() -> MyModel: + """Returns a Pydantic model with non-serializable field.""" + return MyModel(data=NonSerializable()) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "data=NonSerializable()" + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_non_serializable(alist): + """Test that non-JSON-serializable results fall back to str().""" + + class CustomClass: + def __str__(self): + return "custom_str_repr" + + @strands.tool + def custom_tool() -> Any: + """Returns a non-serializable object.""" + return CustomClass() + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = custom_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "custom_str_repr" + + +@pytest.mark.asyncio +async def test_tool_result_string_not_json_encoded(alist): + """Test that string results are NOT JSON-encoded (no extra quotes).""" + + @strands.tool + def string_tool() -> str: + """Returns a string.""" + return "hello world" + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = string_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "hello world" + + def test_docstring_parsing(): """Test that function docstring is correctly parsed into tool spec."""