diff --git a/src/google/adk/utils/instructions_utils.py b/src/google/adk/utils/instructions_utils.py index 505b5cf128..67e2c405ed 100644 --- a/src/google/adk/utils/instructions_utils.py +++ b/src/google/adk/utils/instructions_utils.py @@ -21,10 +21,10 @@ from ..sessions.state import State __all__ = [ - 'inject_session_state', + "inject_session_state", ] -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) async def inject_session_state( @@ -76,18 +76,29 @@ async def _async_sub(pattern, repl_async_fn, string) -> str: result.append(replacement) last_end = match.end() result.append(string[last_end:]) - return ''.join(result) + return "".join(result) async def _replace_match(match) -> str: - var_name = match.group().lstrip('{').rstrip('}').strip() + matched_text = match.group() + + # Check for exactly double braces (escaping) + if ( + matched_text.startswith("{{") + and matched_text.endswith("}}") + and not matched_text.startswith("{{{") + and not matched_text.endswith("}}}") + ): + return matched_text[1:-1] + + var_name = matched_text.lstrip("{").rstrip("}").strip() optional = False - if var_name.endswith('?'): + if var_name.endswith("?"): optional = True - var_name = var_name.removesuffix('?') - if var_name.startswith('artifact.'): - var_name = var_name.removeprefix('artifact.') + var_name = var_name.removesuffix("?") + if var_name.startswith("artifact."): + var_name = var_name.removeprefix("artifact.") if invocation_context.artifact_service is None: - raise ValueError('Artifact service is not initialized.') + raise ValueError("Artifact service is not initialized.") artifact = await invocation_context.artifact_service.load_artifact( app_name=invocation_context.session.app_name, user_id=invocation_context.session.user_id, @@ -97,11 +108,11 @@ async def _replace_match(match) -> str: if artifact is None: if optional: logger.debug( - 'Artifact %s not found, replacing with empty string', var_name + "Artifact %s not found, replacing with empty string", var_name ) - return '' + return "" else: - raise KeyError(f'Artifact {var_name} not found.') + raise KeyError(f"Artifact {var_name} not found.") return str(artifact) else: if not _is_valid_state_name(var_name): @@ -109,19 +120,19 @@ async def _replace_match(match) -> str: if var_name in invocation_context.session.state: value = invocation_context.session.state[var_name] if value is None: - return '' + return "" return str(value) else: if optional: logger.debug( - 'Context variable %s not found, replacing with empty string', + "Context variable %s not found, replacing with empty string", var_name, ) - return '' + return "" else: - raise KeyError(f'Context variable not found: `{var_name}`.') + raise KeyError(f"Context variable not found: `{var_name}`.") - return await _async_sub(r'{+[^{}]*}+', _replace_match, template) + return await _async_sub(r"{+[^{}]*}+", _replace_match, template) def _is_valid_state_name(var_name): @@ -138,12 +149,12 @@ def _is_valid_state_name(var_name): Returns: True if the variable name is a valid state name, False otherwise. """ - parts = var_name.split(':') + parts = var_name.split(":") if len(parts) == 1: return var_name.isidentifier() if len(parts) == 2: prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX] - if (parts[0] + ':') in prefixes: + if (parts[0] + ":") in prefixes: return parts[1].isidentifier() return False diff --git a/tests/unittests/utils/test_instructions_utils.py b/tests/unittests/utils/test_instructions_utils.py index d76e5032ec..b63e73edbc 100644 --- a/tests/unittests/utils/test_instructions_utils.py +++ b/tests/unittests/utils/test_instructions_utils.py @@ -245,6 +245,20 @@ async def test_inject_session_state_with_optional_missing_artifact_returns_empty assert populated_instruction == "Optional artifact: " +@pytest.mark.asyncio +async def test_inject_session_state_with_triple_braces(): + """Triple braces are not treated as escaping - they substitute normally.""" + instruction_template = "Value: {{{name}}}" + invocation_context = await _create_test_readonly_context( + state={"name": "Alice"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Value: Alice" + + @pytest.mark.asyncio async def test_inject_session_state_with_none_state_value_returns_empty(): instruction_template = "Value: {test_key}" @@ -267,3 +281,74 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty(): instruction_template, invocation_context ) assert populated_instruction == "Optional value: " + + +@pytest.mark.asyncio +async def test_inject_session_state_with_double_brace_escaping(): + instruction_template = "Example: {{user_id}}" + invocation_context = await _create_test_readonly_context() + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Example: {user_id}" + + +@pytest.mark.asyncio +async def test_inject_session_state_with_double_brace_escaping_and_normal_substitution(): + instruction_template = "Hello {name}, example: {{variable}}" + invocation_context = await _create_test_readonly_context( + state={"name": "Alice"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello Alice, example: {variable}" + + +@pytest.mark.asyncio +async def test_inject_session_state_with_python_fstring_example(): + instruction_template = """ +Example Python code: +logger.error(f"User not found: {{user_id}}") +""" + invocation_context = await _create_test_readonly_context() + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + expected = """ +Example Python code: +logger.error(f"User not found: {user_id}") +""" + assert populated_instruction == expected + + +@pytest.mark.asyncio +async def test_inject_session_state_with_typescript_template_literal(): + instruction_template = """ +Example TypeScript code: +console.log(`User: ${{userId}}`); +""" + invocation_context = await _create_test_readonly_context() + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + expected = """ +Example TypeScript code: +console.log(`User: ${userId}`); +""" + assert populated_instruction == expected + + +@pytest.mark.asyncio +async def test_inject_session_state_with_multiple_double_brace_patterns(): + instruction_template = "Examples: {{var1}}, {{var2}}, {{var3}}" + invocation_context = await _create_test_readonly_context() + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Examples: {var1}, {var2}, {var3}"