Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
from ._validation import _has_location_source, validate_config_keys
from .model import Model

Expand Down Expand Up @@ -280,11 +280,46 @@ def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai
tools.extend(self.config["gemini_tools"])
return tools

def _map_tool_choice(self, tool_choice: ToolChoice | None) -> genai.types.ToolConfig | None:
    """Translate a Strands ToolChoice into the equivalent Gemini ToolConfig.

    Args:
        tool_choice: Strands tool choice configuration, or None.

    Returns:
        Gemini ToolConfig, or None to fall back to Gemini's default behavior.
    """
    if tool_choice is None:
        return None

    mode_cls = genai.types.FunctionCallingConfigMode

    if "auto" in tool_choice:
        # Model decides whether to call a function.
        calling_config = genai.types.FunctionCallingConfig(mode=mode_cls.AUTO)
    elif "any" in tool_choice:
        # Model must call some function, any of the provided ones.
        calling_config = genai.types.FunctionCallingConfig(mode=mode_cls.ANY)
    elif "tool" in tool_choice:
        # Gemini has no dedicated "force this specific tool" mode; ANY plus an
        # allow-list containing exactly one function name achieves the same effect.
        tool_name = cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]
        calling_config = genai.types.FunctionCallingConfig(
            mode=mode_cls.ANY,
            allowed_function_names=[tool_name],
        )
    else:
        # Unrecognized tool choice shape: defer to Gemini's default behavior.
        return None

    return genai.types.ToolConfig(function_calling_config=calling_config)

def _format_request_config(
self,
tool_specs: list[ToolSpec] | None,
system_prompt: str | None,
params: dict[str, Any] | None,
tool_choice: ToolChoice | None = None,
) -> genai.types.GenerateContentConfig:
"""Format Gemini request config.

Expand All @@ -294,13 +329,18 @@ def _format_request_config(
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
params: Additional model parameters (e.g., temperature).
tool_choice: Selection strategy for tool invocation.

Returns:
Gemini request config.
"""
# Only map tool_choice if tools are provided
tool_config = self._map_tool_choice(tool_choice) if tool_specs else None

return genai.types.GenerateContentConfig(
system_instruction=system_prompt,
tools=self._format_request_tools(tool_specs),
tool_config=tool_config,
**(params or {}),
)

Expand All @@ -310,6 +350,7 @@ def _format_request(
tool_specs: list[ToolSpec] | None,
system_prompt: str | None,
params: dict[str, Any] | None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Format a Gemini streaming request.

Expand All @@ -320,12 +361,13 @@ def _format_request(
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
params: Additional model parameters (e.g., temperature).
tool_choice: Selection strategy for tool invocation.

Returns:
A Gemini streaming request.
"""
return {
"config": self._format_request_config(tool_specs, system_prompt, params).to_json_dict(),
"config": self._format_request_config(tool_specs, system_prompt, params, tool_choice).to_json_dict(),
"contents": [content.to_json_dict() for content in self._format_request_content(messages)],
"model": self.config["model_id"],
}
Expand Down Expand Up @@ -449,7 +491,6 @@ async def stream(
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
Note: Currently unused.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand All @@ -458,7 +499,7 @@ async def stream(
Raises:
ModelThrottledException: If the request is throttled by Gemini.
"""
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"), tool_choice)

client = self._get_client().aio

Expand Down
54 changes: 54 additions & 0 deletions tests/strands/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,60 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)


@pytest.mark.parametrize(
    ("tool_choice", "expected_tool_config"),
    [
        (None, None),
        ({"auto": {}}, {"function_calling_config": {"mode": "AUTO"}}),
        ({"any": {}}, {"function_calling_config": {"mode": "ANY"}}),
        (
            {"tool": {"name": "my_tool"}},
            {"function_calling_config": {"mode": "ANY", "allowed_function_names": ["my_tool"]}},
        ),
    ],
)
def test_map_tool_choice(model, tool_choice, expected_tool_config):
    """Verify _map_tool_choice produces the expected Gemini ToolConfig for each choice shape."""
    tool_config = model._map_tool_choice(tool_choice)

    # Normalize to a plain dict (or None) so a single equality covers every case.
    actual = None if tool_config is None else tool_config.to_json_dict()
    assert actual == expected_tool_config


@pytest.mark.asyncio
async def test_stream_request_with_tool_choice(gemini_client, model, model_id, tool_spec):
    """Verify a streaming request forwards an "any" tool_choice into the Gemini config."""
    await anext(model.stream([], tool_specs=[tool_spec], tool_choice={"any": {}}))

    config = gemini_client.aio.models.generate_content_stream.call_args.kwargs["config"]
    assert config["tool_config"] == {"function_calling_config": {"mode": "ANY"}}


@pytest.mark.asyncio
async def test_stream_request_with_tool_choice_specific_tool(gemini_client, model, model_id, tool_spec):
    """Verify a named tool_choice maps to ANY mode restricted to that single function."""
    await anext(model.stream([], tool_specs=[tool_spec], tool_choice={"tool": {"name": "my_tool"}}))

    config = gemini_client.aio.models.generate_content_stream.call_args.kwargs["config"]
    expected = {"function_calling_config": {"mode": "ANY", "allowed_function_names": ["my_tool"]}}
    assert config["tool_config"] == expected


@pytest.mark.asyncio
async def test_stream_request_tool_choice_ignored_without_tools(gemini_client, model, messages, model_id):
    """Verify tool_choice is dropped from the request when no tool specs are supplied."""
    await anext(model.stream(messages, tool_choice={"any": {}}))

    config = gemini_client.aio.models.generate_content_stream.call_args.kwargs["config"]
    # Without tools, _format_request_config never maps tool_choice, so tool_config stays unset.
    assert config.get("tool_config") is None


@pytest.mark.asyncio
async def test_stream_request_with_tool_use(gemini_client, model, model_id):
"""Test toolUse with reasoningSignature is sent as function_call with thought_signature."""
Expand Down