-
Notifications
You must be signed in to change notification settings - Fork 2.9k
fix: Force AUDIO modality for native-audio models in /run_live (#4206) #4232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a446492
9fc0b22
45ea2bc
4cd51b4
2b88478
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -550,6 +550,28 @@ def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: | |
| return agent_or_app.root_agent | ||
| return agent_or_app | ||
|
|
||
| def _get_effective_modalities( | ||
| self, root_agent: BaseAgent, requested_modalities: List[str] | ||
| ) -> List[str]: | ||
| """Determines effective modalities, forcing AUDIO for native-audio models. | ||
|
|
||
| Native-audio models only support AUDIO modality. This method detects | ||
| native-audio models by checking if the model name contains "native-audio" | ||
| and forces AUDIO modality for those models. | ||
|
|
||
| Args: | ||
| root_agent: The root agent of the application. | ||
| requested_modalities: The modalities requested by the client. | ||
|
|
||
| Returns: | ||
| The effective modalities to use. | ||
| """ | ||
| model = getattr(root_agent, "model", None) | ||
| model_name = model if isinstance(model, str) else "" | ||
| if "native-audio" in model_name: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The string For example: _NATIVE_AUDIO_MODEL_TAG = "native-audio"This would allow you to reference |
||
| return ["AUDIO"] | ||
| return requested_modalities | ||
|
Comment on lines
553
to
573
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve efficiency, this method can be refactored to accept a def _get_effective_modalities(
self, root_agent: BaseAgent, requested_modalities: List[str]
) -> List[str]:
"""Determines effective modalities, forcing AUDIO for native-audio models.
Native-audio models only support AUDIO modality. This method detects
native-audio models by checking if the model name contains "native-audio"
and forces AUDIO modality for those models.
Args:
root_agent: The root agent of the application.
requested_modalities: The modalities requested by the client.
Returns:
The effective modalities to use.
"""
model = getattr(root_agent, "model", None)
model_name = model if isinstance(model, str) else ""
if "native-audio" in model_name:
return ["AUDIO"]
return requested_modalities |
||
|
|
||
| def _create_runner(self, agentic_app: App) -> Runner: | ||
| """Create a runner with common services.""" | ||
| return Runner( | ||
|
|
@@ -1652,7 +1674,10 @@ async def run_agent_live( | |
|
|
||
| async def forward_events(): | ||
| runner = await self.get_runner_async(app_name) | ||
| run_config = RunConfig(response_modalities=modalities) | ||
| effective_modalities = self._get_effective_modalities( | ||
| runner.app.root_agent, modalities | ||
| ) | ||
| run_config = RunConfig(response_modalities=effective_modalities) | ||
|
Comment on lines
1677
to
1680
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following the suggested refactoring of Note that you will also need to update the corresponding unit tests for effective_modalities = self._get_effective_modalities(
runner.app.root_agent, modalities
)
run_config = RunConfig(response_modalities=effective_modalities) |
||
| async with Aclosing( | ||
| runner.run_live( | ||
| session=session, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
|
|
||
| from fastapi.testclient import TestClient | ||
| from google.adk.agents.base_agent import BaseAgent | ||
| from google.adk.agents.llm_agent import LlmAgent | ||
| from google.adk.agents.run_config import RunConfig | ||
| from google.adk.apps.app import App | ||
| from google.adk.artifacts.base_artifact_service import ArtifactVersion | ||
|
|
@@ -1411,5 +1412,91 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): | |
| assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() | ||
|
|
||
|
|
||
| def test_native_audio_model_forces_audio_modality(): | ||
| """Test that native-audio models force AUDIO modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| native_audio_agent = LlmAgent( | ||
| name="native_audio_agent", | ||
| model="gemini-live-2.5-flash-native-audio", | ||
| ) | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
Comment on lines
+1424
to
+1433
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The instantiation of Here's an example of what that fixture could look like: @pytest.fixture
def adk_web_server_for_modality_tests():
"""Provides an AdkWebServer instance with mocked services for modality tests."""
from google.adk.cli.adk_web_server import AdkWebServer
return AdkWebServer(
agent_loader=MagicMock(),
session_service=MagicMock(),
memory_service=MagicMock(),
artifact_service=MagicMock(),
credential_service=MagicMock(),
eval_sets_manager=MagicMock(),
eval_set_results_manager=MagicMock(),
agents_dir=".",
)Each test could then accept |
||
|
|
||
| # Test: requesting TEXT should be forced to AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| native_audio_agent, ["TEXT"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
| # Test: requesting AUDIO should stay AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| native_audio_agent, ["AUDIO"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
|
|
||
| def test_non_native_audio_model_keeps_requested_modality(): | ||
| """Test that non-native-audio models keep the requested modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| regular_agent = LlmAgent( | ||
| name="regular_agent", | ||
| model="gemini-2.5-flash", | ||
| ) | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
|
||
| # Test: requesting TEXT should stay TEXT | ||
| modalities = adk_web_server._get_effective_modalities(regular_agent, ["TEXT"]) | ||
| assert modalities == ["TEXT"] | ||
|
|
||
| # Test: requesting AUDIO should stay AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| regular_agent, ["AUDIO"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
|
|
||
| def test_agent_without_model_attribute(): | ||
| """Test that agents without model attribute keep requested modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| base_agent = DummyAgent(name="base_agent") | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
|
||
| # Test: BaseAgent without model attr should keep requested modality | ||
| modalities = adk_web_server._get_effective_modalities(base_agent, ["TEXT"]) | ||
| assert modalities == ["TEXT"] | ||
|
|
||
|
Comment on lines
1415
to
1499
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These new tests are a good start, but they have a couple of issues that make them less effective and harder to maintain:
A better approach would be to refactor the logic for determining the For example, you could add a method to def _get_effective_modalities(self, app_name: str, requested_modalities: list[str]) -> list[str]:
"""Determines the effective modalities, forcing AUDIO for native-audio models."""
agent_or_app = self.agent_loader.load_agent(app_name)
root_agent = self._get_root_agent(agent_or_app)
model = getattr(root_agent, "model", None)
model_name = model if isinstance(model, str) else ""
if "native-audio" in model_name:
return ["AUDIO"]
return requested_modalitiesThen your tests could be simplified to something like this, which is much more direct and robust: def test_native_audio_model_forces_audio_modality():
# ... setup adk_web_server with NativeAudioAgentLoader ...
modalities = adk_web_server._get_effective_modalities("test_app", ["TEXT"])
assert modalities == ["AUDIO"]
def test_non_native_audio_model_keeps_requested_modality():
# ... setup adk_web_server with RegularAgentLoader ...
modalities = adk_web_server._get_effective_modalities("test_app", ["TEXT"])
assert modalities == ["TEXT"]
def test_agent_without_model_attribute():
# ... setup adk_web_server with BaseAgentLoader ...
modalities = adk_web_server._get_effective_modalities("test_app", ["TEXT"])
assert modalities == ["TEXT"]This would make the tests much stronger and easier to maintain.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, will fix this! |
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main(["-xvs", __file__]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current logic for extracting the model name only handles the case where the
modelattribute is a string. TheLlmAgent.modelattribute can also be aBaseLlmobject, in which caseisinstance(model, str)would be false,model_namewould become an empty string, and the check for "native-audio" would fail.To make this more robust, you should also handle the case where
modelis an object (likeBaseLlm) that has amodelstring attribute. It would also be beneficial to add a test case for anLlmAgentinitialized with aBaseLlmobject to ensure full coverage.