From 68e57eb3ef3c0291dc2f2ed6e1d4ecfccfdff3f9 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Delafosse Date: Fri, 27 Feb 2026 09:28:21 +0100 Subject: [PATCH 1/2] feat(gemini): implement tool_choice support Add tool_choice parameter support to GeminiModel, enabling Strands' structured output retry mechanism to work correctly with Gemini. - Add _map_tool_choice() helper to convert Strands ToolChoice to Gemini's ToolConfig with FunctionCallingConfigMode - Update _format_request_config() to accept and apply tool_choice - Update _format_request() to pass tool_choice through - Update stream() to pass tool_choice to request formatting Mapping: - {"auto": {}} -> FunctionCallingConfigMode.AUTO - {"any": {}} -> FunctionCallingConfigMode.ANY - {"tool": {"name": "X"}} -> FunctionCallingConfigMode.ANY with allowed_function_names=["X"] Closes #1129 --- src/strands/models/gemini.py | 49 +++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c94570293..e75139005 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -18,7 +18,7 @@ from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolChoice, ToolSpec +from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec from ._validation import _has_location_source, validate_config_keys from .model import Model @@ -280,11 +280,46 @@ def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai tools.extend(self.config["gemini_tools"]) return tools + def _map_tool_choice(self, tool_choice: ToolChoice | None) -> genai.types.ToolConfig | None: + """Map Strands ToolChoice to Gemini ToolConfig. + + Args: + tool_choice: Strands tool choice configuration. + + Returns: + Gemini ToolConfig or None for default behavior. + """ + if tool_choice is None: + return None + + if "auto" in tool_choice: + return genai.types.ToolConfig( + function_calling_config=genai.types.FunctionCallingConfig( + mode=genai.types.FunctionCallingConfigMode.AUTO + ) + ) + elif "any" in tool_choice: + return genai.types.ToolConfig( + function_calling_config=genai.types.FunctionCallingConfig( + mode=genai.types.FunctionCallingConfigMode.ANY + ) + ) + elif "tool" in tool_choice: + return genai.types.ToolConfig( + function_calling_config=genai.types.FunctionCallingConfig( + mode=genai.types.FunctionCallingConfigMode.ANY, + allowed_function_names=[cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]], + ) + ) + + return None + def _format_request_config( self, tool_specs: list[ToolSpec] | None, system_prompt: str | None, params: dict[str, Any] | None, + tool_choice: ToolChoice | None = None, ) -> genai.types.GenerateContentConfig: """Format Gemini request config. @@ -294,13 +329,18 @@ def _format_request_config( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. params: Additional model parameters (e.g., temperature). + tool_choice: Selection strategy for tool invocation. Returns: Gemini request config. """ + # Only map tool_choice if tools are provided + tool_config = self._map_tool_choice(tool_choice) if tool_specs else None + return genai.types.GenerateContentConfig( system_instruction=system_prompt, tools=self._format_request_tools(tool_specs), + tool_config=tool_config, **(params or {}), ) @@ -310,6 +350,7 @@ def _format_request( tool_specs: list[ToolSpec] | None, system_prompt: str | None, params: dict[str, Any] | None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Gemini streaming request. @@ -320,12 +361,13 @@ def _format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. params: Additional model parameters (e.g., temperature). + tool_choice: Selection strategy for tool invocation. Returns: A Gemini streaming request. """ return { - "config": self._format_request_config(tool_specs, system_prompt, params).to_json_dict(), + "config": self._format_request_config(tool_specs, system_prompt, params, tool_choice).to_json_dict(), "contents": [content.to_json_dict() for content in self._format_request_content(messages)], "model": self.config["model_id"], } @@ -449,7 +491,6 @@ async def stream( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. - Note: Currently unused. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -458,7 +499,7 @@ async def stream( Raises: ModelThrottledException: If the request is throttled by Gemini. """ - request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) + request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"), tool_choice) client = self._get_client().aio From e229c9e4aad9e86cb7528035e6b49c39a32b1484 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Delafosse Date: Fri, 27 Feb 2026 09:28:48 +0100 Subject: [PATCH 2/2] test(gemini): add unit tests for tool_choice mapping Add comprehensive tests for the new tool_choice functionality: - test_map_tool_choice: Parametrized test covering None, auto, any, and specific tool cases - test_stream_request_with_tool_choice: Verify tool_choice passes through to request config - test_stream_request_with_tool_choice_specific_tool: Verify specific tool name is included in allowed_function_names - test_stream_request_tool_choice_ignored_without_tools: Verify tool_choice is ignored when no tools are provided --- tests/strands/models/test_gemini.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index ba4b2b53f..398f5e2a4 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -258,6 +258,60 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) +@pytest.mark.parametrize( + ("tool_choice", "expected_tool_config"), + [ + (None, None), + ({"auto": {}}, {"function_calling_config": {"mode": "AUTO"}}), + ({"any": {}}, {"function_calling_config": {"mode": "ANY"}}), + ( + {"tool": {"name": "my_tool"}}, + {"function_calling_config": {"mode": "ANY", "allowed_function_names": ["my_tool"]}}, + ), + ], +) +def test_map_tool_choice(model, tool_choice, expected_tool_config): + """Test _map_tool_choice returns correct Gemini ToolConfig.""" + result = model._map_tool_choice(tool_choice) + if expected_tool_config is None: + assert result is None + else: + assert result.to_json_dict() == expected_tool_config + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_choice(gemini_client, model, model_id, tool_spec): + """Test that tool_choice is passed through to the request.""" + tool_choice = {"any": {}} + await anext(model.stream([], tool_specs=[tool_spec], tool_choice=tool_choice)) + + call_kwargs = gemini_client.aio.models.generate_content_stream.call_args.kwargs + assert call_kwargs["config"]["tool_config"] == {"function_calling_config": {"mode": "ANY"}} + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_choice_specific_tool(gemini_client, model, model_id, tool_spec): + """Test that tool_choice with specific tool name is passed through.""" + tool_choice = {"tool": {"name": "my_tool"}} + await anext(model.stream([], tool_specs=[tool_spec], tool_choice=tool_choice)) + + call_kwargs = gemini_client.aio.models.generate_content_stream.call_args.kwargs + assert call_kwargs["config"]["tool_config"] == { + "function_calling_config": {"mode": "ANY", "allowed_function_names": ["my_tool"]} + } + + +@pytest.mark.asyncio +async def test_stream_request_tool_choice_ignored_without_tools(gemini_client, model, messages, model_id): + """Test that tool_choice is ignored when no tools are provided.""" + tool_choice = {"any": {}} + await anext(model.stream(messages, tool_choice=tool_choice)) + + call_kwargs = gemini_client.aio.models.generate_content_stream.call_args.kwargs + # tool_config should not be present when no tools are provided + assert call_kwargs["config"].get("tool_config") is None + + @pytest.mark.asyncio async def test_stream_request_with_tool_use(gemini_client, model, model_id): """Test toolUse with reasoningSignature is sent as function_call with thought_signature."""