Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,37 @@
from ...agent.agent import Agent

from ...hooks import BeforeModelCallEvent, HookRegistry
from ...types.content import Messages
from ...types.content import ContentBlock, Messages
from ...types.exceptions import ContextWindowOverflowException
from ...types.tools import ToolResultContent
from .conversation_manager import ConversationManager

logger = logging.getLogger(__name__)

_PRESERVE_CHARS = 200


class SlidingWindowConversationManager(ConversationManager):
"""Implements a sliding window strategy for managing conversation history.
This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
invalid window states.
When truncation is enabled (the default), large tool results are partially truncated, preserving the first
and last 200 characters, and image blocks inside tool results are replaced with descriptive text placeholders.
Truncation targets the oldest tool results first so the most relevant recent context is preserved as long
as possible.
Supports proactive management during agent loop execution via the per_turn parameter.
"""

def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False):
def __init__(
self,
window_size: int = 40,
should_truncate_results: bool = True,
*,
per_turn: bool | int = False,
):
"""Initialize the sliding window conversation manager.
Args:
Expand All @@ -44,6 +58,9 @@ def __init__(self, window_size: int = 40, should_truncate_results: bool = True,
Raises:
ValueError: If per_turn is 0 or a negative integer.
"""
if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0:
raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}")

super().__init__()

self.window_size = window_size
Expand Down Expand Up @@ -157,14 +174,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
messages = agent.messages

# Try to truncate the tool result first
last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages)
if last_message_idx_with_tool_results is not None and self.should_truncate_results:
oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages)
if oldest_message_idx_with_tool_results is not None and self.should_truncate_results:
logger.debug(
"message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results
"message_index=<%s> | found message with tool results at index", oldest_message_idx_with_tool_results
)
results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results)
results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results)
if results_truncated:
logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results)
return

# Fall back to trimming the oldest messages when tool results cannot be truncated any further
Expand Down Expand Up @@ -197,10 +214,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
messages[:] = messages[trim_index:]

def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results in a message to reduce context size.
"""Truncate tool results and replace image blocks in a message to reduce context size.
For text blocks within tool results, all blocks are partially truncated unless they
have already been truncated. The first and last _PRESERVE_CHARS characters are kept,
and the removed middle is replaced with a notice indicating how many characters were
removed. The tool result status is not changed.
When a message contains tool results that are too large for the model's context window, this function
replaces the content of those tool results with a simple error message.
Image blocks nested inside tool result content are replaced with a short descriptive placeholder.
Args:
messages: The conversation message history.
Expand All @@ -212,52 +233,82 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
if msg_idx >= len(messages) or msg_idx < 0:
return False

def _image_placeholder(image_block: Any) -> str:
source: Any = image_block.get("source", {})
media_type = image_block.get("format", "unknown")
data = source.get("bytes", b"")
return f"[image: {media_type}, {len(data) if data else 0} bytes]"

message = messages[msg_idx]
changes_made = False
tool_result_too_large_message = "The tool result was too large!"
for i, content in enumerate(message.get("content", [])):
if isinstance(content, dict) and "toolResult" in content:
tool_result_content_text = next(
(item["text"] for item in content["toolResult"]["content"] if "text" in item),
"",
)
# make the overwriting logic togglable
if (
message["content"][i]["toolResult"]["status"] == "error"
and tool_result_content_text == tool_result_too_large_message
):
logger.info("ToolResult has already been updated, skipping overwrite")
return False
# Update status to error with informative message
message["content"][i]["toolResult"]["status"] = "error"
message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
changes_made = True
new_content: list[ContentBlock] = []

for content in message.get("content", []):
if "toolResult" in content:
tool_result: Any = content["toolResult"]
tool_result_items = tool_result.get("content", [])
new_items: list[ToolResultContent] = []
item_changed = False

for item in tool_result_items:
# Replace image items nested inside toolResult content
if "image" in item:
new_items.append({"text": _image_placeholder(item["image"])})
item_changed = True
continue

# Partially truncate text items that have not already been truncated
if "text" in item:
text = item["text"]
truncation_marker = "... [truncated:"
if truncation_marker not in text and len(text) > 2 * _PRESERVE_CHARS:
prefix = text[:_PRESERVE_CHARS]
suffix = text[-_PRESERVE_CHARS:]
removed = len(text) - 2 * _PRESERVE_CHARS
truncated_text = (
f"{prefix}...\n\n... [truncated: {removed} chars removed] ...\n\n...{suffix}"
)
new_items.append({"text": truncated_text})
item_changed = True
continue

new_items.append(item)

if item_changed:
updated_tool_result: Any = {
**{k: v for k, v in tool_result.items() if k != "content"},
"content": new_items,
}
new_content.append({"toolResult": updated_tool_result})
changes_made = True
else:
new_content.append(content)
continue

new_content.append(content)

if changes_made:
message["content"] = new_content

return changes_made

def _find_last_message_with_tool_results(self, messages: Messages) -> int | None:
"""Find the index of the last message containing tool results.
def _find_oldest_message_with_tool_results(self, messages: Messages) -> int | None:
"""Find the index of the oldest message containing tool results.
This is useful for identifying messages that might need to be truncated to reduce context size.
Iterates from oldest to newest so that truncation targets the least-recent
(and therefore least relevant) tool results first.
Args:
messages: The conversation message history.
Returns:
Index of the last message with tool results, or None if no such message exists.
Index of the oldest message with tool results, or None if no such message exists.
"""
# Iterate backwards through all messages (from newest to oldest)
for idx in range(len(messages) - 1, -1, -1):
# Check if this message has any content with toolResult
# Iterate from oldest to newest
for idx in range(len(messages)):
current_message = messages[idx]
has_tool_result = False

for content in current_message.get("content", []):
if isinstance(content, dict) and "toolResult" in content:
has_tool_result = True
break

if has_tool_result:
return idx
return idx

return None
10 changes: 6 additions & 4 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene
},
},
},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "' + "X" * 500 + '"}'}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "tool_use"}},
]
Expand All @@ -635,12 +635,14 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene

agent("test message")

large_input = "X" * 500
truncated_text = large_input[:200] + "...\n\n... [truncated: 100 chars removed] ...\n\n..." + large_input[-200:]
expected_messages = [
{"role": "user", "content": [{"text": "test message"}]},
{
"role": "assistant",
"content": [
{"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": "abcdEfghI123"}}}
{"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": large_input}}}
],
},
{
Expand All @@ -649,8 +651,8 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene
{
"toolResult": {
"toolUseId": "t1",
"status": "error",
"content": [{"text": "The tool result was too large!"}],
"status": "success",
"content": [{"text": truncated_text}],
}
}
],
Expand Down
Loading
Loading