Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ jobs:
with:
python-version: '3.10'
- name: Install dependencies
# Pin virtualenv until hatch is fixed.
# See https://github.com/pypa/hatch/issues/2193
run: |
pip install --no-cache-dir hatch
pip install --no-cache-dir hatch 'virtualenv<21'
- name: Run integration tests
env:
AWS_REGION: us-east-1
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/pypi-publish-on-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ jobs:
python-version: '3.10'

- name: Install dependencies
# Pin virtualenv until hatch is fixed.
# See https://github.com/pypa/hatch/issues/2193
run: |
python -m pip install --upgrade pip
pip install hatch twine
pip install hatch twine 'virtualenv<21'

- name: Validate version
run: |
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/test-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ jobs:
# Windows typically has audio libraries available by default
echo "Windows audio dependencies handled by PyAudio wheels"
- name: Install dependencies
# Pin virtualenv until hatch is fixed.
# See https://github.com/pypa/hatch/issues/2193
run: |
pip install --no-cache-dir hatch
pip install --no-cache-dir hatch 'virtualenv<21'
- name: Run Unit tests
id: tests
run: hatch test tests --cover
Expand Down Expand Up @@ -118,8 +120,10 @@ jobs:
sudo apt-get install -y portaudio19-dev libasound2-dev
- name: Install dependencies
# Pin virtualenv until hatch is fixed.
# See https://github.com/pypa/hatch/issues/2193
run: |
pip install --no-cache-dir hatch
pip install --no-cache-dir hatch 'virtualenv<21'
- name: Run lint
id: lint
Expand Down
19 changes: 17 additions & 2 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
import asyncio
import functools
import inspect
import json
import logging
from collections.abc import Callable
from typing import (
Expand All @@ -61,6 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
import docstring_parser
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticSerializationError
from typing_extensions import override

from ..interrupt import InterruptException
Expand Down Expand Up @@ -644,12 +646,25 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception |
return ToolResultEvent(cast(ToolResult, result), exception=exception)
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
# Serialize to JSON for consistent, parseable output (except strings)
if isinstance(result, str):
text = result
elif isinstance(result, BaseModel):
try:
text = result.model_dump_json()
except PydanticSerializationError:
text = str(result)
else:
try:
text = json.dumps(result)
except (TypeError, ValueError):
text = str(result)

return ToolResultEvent(
{
"toolUseId": tool_use_d,
"status": "success",
"content": [{"text": str(result)}],
"content": [{"text": text}],
},
exception=exception,
)
Expand Down
141 changes: 132 additions & 9 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def identity(a: int, agent: dict = None):

tru_events = await alist(stream)
exp_events = [
ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]})
ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": '[2, {"state": 1}]'}]})
]
assert tru_events == exp_events

Expand Down Expand Up @@ -595,12 +595,12 @@ def none_return_tool(param: str) -> None:
assert result["tool_result"]["status"] == "success"
assert result["tool_result"]["content"][0]["text"] == "Result: test"

# Test None return - should still create valid ToolResult with "None" text
# Test None return - should still create valid ToolResult with "null"
stream = none_return_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "success"
assert result["tool_result"]["content"][0]["text"] == "None"
assert result["tool_result"]["content"][0]["text"] == "null"


@pytest.mark.asyncio
Expand Down Expand Up @@ -861,7 +861,7 @@ def int_return_tool(param: str) -> int:

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "success"
assert result["tool_result"]["content"][0]["text"] == "None"
assert result["tool_result"]["content"][0]["text"] == "null"

# Define tool with Union return type
@strands.tool
Expand All @@ -884,10 +884,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None:

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "success"
assert (
"{'key': 'value'}" in result["tool_result"]["content"][0]["text"]
or '{"key": "value"}' in result["tool_result"]["content"][0]["text"]
)
assert result["tool_result"]["content"][0]["text"] == '{"key": "value"}'

tool_use = {"toolUseId": "test-id", "input": {"param": "str"}}
stream = union_return_tool.stream(tool_use, {})
Expand All @@ -901,7 +898,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None:

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "success"
assert result["tool_result"]["content"][0]["text"] == "None"
assert result["tool_result"]["content"][0]["text"] == "null"


@pytest.mark.asyncio
Expand Down Expand Up @@ -992,6 +989,132 @@ def custom_result_tool(param: str) -> dict[str, Any]:
assert result["tool_result"]["content"][1]["type"] == "markdown"


@pytest.mark.asyncio
async def test_tool_result_json_serialization_dict(alist):
"""Test that dict results are serialized as JSON."""

@strands.tool
def dict_tool() -> dict:
"""Returns a dict."""
return {"key": "value", "number": 42}

tool_use = {"toolUseId": "test-id", "input": {}}
stream = dict_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == '{"key": "value", "number": 42}'


@pytest.mark.asyncio
async def test_tool_result_json_serialization_list(alist):
"""Test that list results are serialized as JSON."""

@strands.tool
def list_tool() -> list:
"""Returns a list."""
return [1, "two", {"three": 3}]

tool_use = {"toolUseId": "test-id", "input": {}}
stream = list_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == '[1, "two", {"three": 3}]'


@pytest.mark.asyncio
async def test_tool_result_json_serialization_pydantic(alist):
"""Test that Pydantic model results are serialized as JSON."""
from pydantic import BaseModel

class MyModel(BaseModel):
name: str
count: int

@strands.tool
def pydantic_tool() -> MyModel:
"""Returns a Pydantic model."""
return MyModel(name="test", count=5)

tool_use = {"toolUseId": "test-id", "input": {}}
stream = pydantic_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == '{"name":"test","count":5}'


@pytest.mark.asyncio
async def test_tool_result_json_serialization_pydantic_non_serializable(alist):
"""Test that Pydantic models with non-serializable fields fall back to str()."""
from pydantic import BaseModel

class NonSerializable:
def __repr__(self):
return "NonSerializable()"

class MyModel(BaseModel):
model_config = {"arbitrary_types_allowed": True}
data: NonSerializable

@strands.tool
def pydantic_tool() -> MyModel:
"""Returns a Pydantic model with non-serializable field."""
return MyModel(data=NonSerializable())

tool_use = {"toolUseId": "test-id", "input": {}}
stream = pydantic_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == "data=NonSerializable()"


@pytest.mark.asyncio
async def test_tool_result_json_serialization_non_serializable(alist):
"""Test that non-JSON-serializable results fall back to str()."""

class CustomClass:
def __str__(self):
return "custom_str_repr"

@strands.tool
def custom_tool() -> Any:
"""Returns a non-serializable object."""
return CustomClass()

tool_use = {"toolUseId": "test-id", "input": {}}
stream = custom_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == "custom_str_repr"


@pytest.mark.asyncio
async def test_tool_result_string_not_json_encoded(alist):
"""Test that string results are NOT JSON-encoded (no extra quotes)."""

@strands.tool
def string_tool() -> str:
"""Returns a string."""
return "hello world"

tool_use = {"toolUseId": "test-id", "input": {}}
stream = string_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
text = result["tool_result"]["content"][0]["text"]

assert text == "hello world"


def test_docstring_parsing():
"""Test that function docstring is correctly parsed into tool spec."""

Expand Down
Loading