From 10178ce2c68c7dd3302b087cdadab36123e67200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Mon, 9 Feb 2026 11:30:36 +0800 Subject: [PATCH 1/4] feat(memory_collection): add MemoryConversation class for dialogue history management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I5426a1c75c42e92c568de1e953b6422e24ff240e Co-developed-by: Cursor Signed-off-by: 久氢 --- agentrun/memory_collection/__init__.py | 2 + .../memory_collection/memory_conversation.py | 487 ++++++++++++++++++ examples/server_with_memory.py | 87 ++++ .../test_memory_conversation.py | 270 ++++++++++ 4 files changed, 846 insertions(+) create mode 100644 agentrun/memory_collection/memory_conversation.py create mode 100644 examples/server_with_memory.py create mode 100644 tests/unittests/memory_collection/test_memory_conversation.py diff --git a/agentrun/memory_collection/__init__.py b/agentrun/memory_collection/__init__.py index 6f3bd0f..2b1a949 100644 --- a/agentrun/memory_collection/__init__.py +++ b/agentrun/memory_collection/__init__.py @@ -6,6 +6,7 @@ from .client import MemoryCollectionClient from .memory_collection import MemoryCollection +from .memory_conversation import MemoryConversation from .model import ( EmbedderConfig, EmbedderConfigConfig, @@ -23,6 +24,7 @@ __all__ = [ "MemoryCollection", "MemoryCollectionClient", + "MemoryConversation", "MemoryCollectionCreateInput", "MemoryCollectionUpdateInput", "MemoryCollectionListInput", diff --git a/agentrun/memory_collection/memory_conversation.py b/agentrun/memory_collection/memory_conversation.py new file mode 100644 index 0000000..0903547 --- /dev/null +++ b/agentrun/memory_collection/memory_conversation.py @@ -0,0 +1,487 @@ +"""AgentRun Memory Conversation / AgentRun 记忆对话 + +提供与 TableStore Memory 的集成能力,自动存储用户和 Agent 的对话历史。 + +Example (基本使用): + >>> from agentrun.server import AgentRunServer, AgentRequest + >>> from agentrun.memory_collection import MemoryConversation + >>> + >>> # 初始化 Memory Conversation + >>> memory = MemoryConversation(memory_collection_name="my-memory") + >>> + >>> # 包装 invoke_agent 函数 + >>> async def invoke_agent(req: AgentRequest): + ... async for event in memory.wrap_invoke_agent(req, my_agent_handler): + ... yield event + >>> + >>> server = AgentRunServer(invoke_agent=invoke_agent) + >>> server.start() +""" + +import json +import os +from typing import Any, AsyncIterator, Callable, Dict, Optional, TYPE_CHECKING +import uuid + +import tablestore + +from agentrun.utils.config import Config +from agentrun.utils.log import logger + +if TYPE_CHECKING: + from agentrun.server.model import ( + AgentEvent, + AgentRequest, + EventType, + MessageRole, + ) + + +class MemoryConversation: + """Memory Conversation / 记忆对话 + + 自动将用户和 Agent 的对话存储到 TableStore Memory 中。 + + Attributes: + memory_collection_name: MemoryCollection 名称 + config: AgentRun 配置 + user_id_extractor: 从请求中提取 user_id 的函数 + session_id_extractor: 从请求中提取 session_id 的函数 + agent_id_extractor: 从请求中提取 agent_id 的函数 + """ + + def __init__( + self, + memory_collection_name: str, + config: Optional[Config] = None, + user_id_extractor: Optional[Callable[[Any], str]] = None, + session_id_extractor: Optional[Callable[[Any], str]] = None, + agent_id_extractor: Optional[Callable[[Any], str]] = None, + ): + """初始化 Memory Conversation + + Args: + memory_collection_name: MemoryCollection 名称 + config: AgentRun 配置(可选,默认从环境变量读取) + user_id_extractor: 从请求中提取 user_id 的函数(可选) + session_id_extractor: 从请求中提取 session_id 的函数(可选) + agent_id_extractor: 从请求中提取 agent_id 的函数(可选) + """ + self.memory_collection_name = memory_collection_name + self.config = config or Config() + self.user_id_extractor = ( + user_id_extractor or self._default_user_id_extractor + ) + self.session_id_extractor = ( + session_id_extractor or self._default_session_id_extractor + ) + self.agent_id_extractor = ( + agent_id_extractor or self._default_agent_id_extractor + ) + + # 延迟初始化 + self._memory_store = None + self._ots_client = None + + @staticmethod + def _default_user_id_extractor(req: Any) -> str: + """默认的 user_id 提取器 + + 优先级: + 1. X-User-ID 请求头 + 2. user_id 查询参数 + 3. 默认值 "default_user" + """ + if req.raw_request: + # 从请求头获取 + user_id = req.raw_request.headers.get("X-User-ID") + if user_id: + return user_id + + # 从查询参数获取 + user_id = req.raw_request.query_params.get("user_id") + if user_id: + return user_id + + return "default_user" + + @staticmethod + def _default_session_id_extractor(req: Any) -> str: + """默认的 session_id 提取器 + + 优先级: + 1. X-Session-ID 请求头 + 2. sessionId 查询参数 + 3. 从最后一条消息的 id 生成 + 4. 生成新的 UUID + """ + if req.raw_request: + # 从请求头获取 + session_id = req.raw_request.headers.get("X-Conversation-ID") + if session_id: + return session_id + + # 从查询参数获取 + session_id = req.raw_request.query_params.get("sessionId") + if session_id: + return session_id + + # 从消息 ID 生成(如果有) + if req.messages and req.messages[-1].id: + return f"session_{req.messages[-1].id}" + + # 生成新的 session_id + return f"session_{uuid.uuid4().hex[:16]}" + + @staticmethod + def _default_agent_id_extractor(req: Any) -> str: + """默认的 agent_id 提取器 + + 优先级: + 1. X-Agent-ID 请求头 + 2. 从 URL 路径中提取 /agent-runtimes/{agent_id}/... 格式 + 3. 默认值 "default_agent" + """ + if req.raw_request: + # 从请求头获取 + agent_id = req.raw_request.headers.get("X-Agent-ID") + if agent_id: + return agent_id + + # 从 URL 路径中提取 + # 例如:/agent-runtimes/agent-quick-xFGD/invoke -> agent-quick-xFGD + try: + path = ( + req.raw_request.url.path + if hasattr(req.raw_request.url, "path") + else str(req.raw_request.url) + ) + if "/agent-runtimes/" in path: + # 提取 /agent-runtimes/ 后面的部分 + parts = path.split("/agent-runtimes/", 1) + if len(parts) > 1: + # 获取下一个路径段 + agent_part = parts[1].split("/")[0] + if agent_part: + return agent_part + except Exception: + pass + + return "default_agent" + + async def _get_memory_store(self): + """获取或创建 AsyncMemoryStore 实例""" + if self._memory_store is not None: + return self._memory_store + + try: + # 导入依赖 + from tablestore_for_agent_memory.base.base_memory_store import ( + Message, + Session, + ) + from tablestore_for_agent_memory.base.common import ( + microseconds_timestamp, + ) + from tablestore_for_agent_memory.memory.async_memory_store import ( + AsyncMemoryStore, + ) + except ImportError as e: + raise ImportError( + "tablestore-for-agent-memory package is required. " + "Install it with: pip install tablestore-for-agent-memory" + ) from e + + # 从 MemoryCollection 获取配置 + ots_config = await self._get_ots_config_from_memory_collection() + + # 创建 AsyncOTSClient + self._ots_client = tablestore.AsyncOTSClient( + end_point=ots_config["endpoint"], + access_key_id=ots_config["access_key_id"], + access_key_secret=ots_config["access_key_secret"], + instance_name=ots_config["instance_name"], + ) + + # 配置会话表的二级索引元数据字段 + # agent_id 字段用于标识会话所属的 Agent + from tablestore_for_agent_memory.base.common import MetaType + + session_secondary_index_meta = { + "agent_id": ( + MetaType.STRING + ), # Agent 标识符,用于区分不同的 AI Agent + } + + # 配置会话表的搜索索引结构 + # agent_id: Agent 标识符,支持精确匹配查询 + session_search_index_schema = [ + tablestore.FieldSchema("agent_id", tablestore.FieldType.KEYWORD), + ] + + # 配置消息表的搜索索引结构(消息表不需要额外索引) + message_search_index_schema = [] + + # 创建 AsyncMemoryStore + self._memory_store = AsyncMemoryStore( + tablestore_client=self._ots_client, + session_secondary_index_meta=session_secondary_index_meta, + session_search_index_schema=session_search_index_schema, + message_search_index_schema=message_search_index_schema, + ) + + # 初始化表和索引(如果表已存在会忽略错误) + try: + logger.info( + "Initializing tables and indexes for collection:" + f" {self.memory_collection_name}" + ) + await self._memory_store.init_table() + await self._memory_store.init_search_index() + logger.info(f"Tables and indexes initialized successfully") + except Exception as e: + # 如果表已存在,会抛出异常,这是正常的 + logger.info( + "Tables and indexes already exist or initialization" + f" skipped: {e}" + ) + + logger.info( + "Memory Store initialized for collection:" + f" {self.memory_collection_name}" + ) + + return self._memory_store + + async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]: + """从 MemoryCollection 获取 OTS 配置信息 + + Returns: + Dict[str, Any]: OTS 配置字典,包含: + - endpoint: OTS endpoint + - access_key_id: 访问密钥 ID + - access_key_secret: 访问密钥 Secret + - instance_name: OTS 实例名称 + """ + from agentrun.memory_collection import MemoryCollection + + # 获取 MemoryCollection + memory_collection = await MemoryCollection.get_by_name_async( + self.memory_collection_name, config=self.config + ) + + if not memory_collection.vector_store_config: + raise ValueError( + f"MemoryCollection {self.memory_collection_name} does not have " + "vector_store_config" + ) + + vector_store_config = memory_collection.vector_store_config + provider = vector_store_config.provider or "" + + # 只支持 aliyun_tablestore provider + if provider != "aliyun_tablestore": + raise ValueError( + f"Only aliyun_tablestore provider is supported, got: {provider}" + ) + + if not vector_store_config.config: + raise ValueError( + f"MemoryCollection {self.memory_collection_name} does not have " + "vector_store_config.config" + ) + + vs_config = vector_store_config.config + + # 获取 endpoint 并根据运行环境决定是否转换为公网地址 + endpoint = vs_config.endpoint or "" + is_running_on_fc = os.getenv("FC_REGION") is not None + if not is_running_on_fc and ".vpc.tablestore.aliyuncs.com" in endpoint: + original_endpoint = endpoint + endpoint = endpoint.replace( + ".vpc.tablestore.aliyuncs.com", ".tablestore.aliyuncs.com" + ) + logger.info( + "Running on local, converted VPC endpoint to public endpoint:" + f" {original_endpoint} -> {endpoint}" + ) + + # 构建 OTS 配置 + ots_config = { + "endpoint": endpoint, + "instance_name": vs_config.instance_name or "", + "access_key_id": self.config.get_access_key_id(), + "access_key_secret": self.config.get_access_key_secret(), + } + + return ots_config + + async def wrap_invoke_agent( + self, + request: Any, + agent_handler: Callable[[Any], AsyncIterator[Any]], + ) -> AsyncIterator[Any]: + """包装 invoke_agent 函数,自动存储对话历史 + + Args: + request: AgentRequest 对象 + agent_handler: 原始的 agent 处理函数 + + Yields: + Any: Agent 返回的事件或字符串 + + Example: + >>> async def my_agent(req: AgentRequest): + ... yield "Hello, world!" + >>> + >>> async def invoke_agent(req: AgentRequest): + ... async for event in memory.wrap_invoke_agent(req, my_agent): + ... yield event + """ + try: + # 导入依赖 + from tablestore_for_agent_memory.base.base_memory_store import ( + Message, + Session, + ) + from tablestore_for_agent_memory.base.common import ( + microseconds_timestamp, + ) + + from agentrun.server.model import AgentEvent, EventType, MessageRole + except ImportError as e: + logger.warning( + "tablestore-for-agent-memory not installed, skipping memory" + " storage" + ) + # 如果没有安装依赖,直接透传 + async for event in agent_handler(request): + yield event + return + + # 提取 user_id、session_id 和 agent_id + user_id = self.user_id_extractor(request) + session_id = self.session_id_extractor(request) + agent_id = self.agent_id_extractor(request) + + logger.debug( + f"Memory: user_id={user_id}, session_id={session_id}," + f" agent_id={agent_id}" + ) + + # 获取 MemoryStore + try: + memory_store = await self._get_memory_store() + except Exception as e: + logger.error( + f"Failed to initialize memory store: {e}", exc_info=True + ) + # 初始化失败,直接透传 + async for event in agent_handler(request): + yield event + return + + # 创建或更新 Session + current_time = microseconds_timestamp() + session = Session( + user_id=user_id, + session_id=session_id, + update_time=current_time, + metadata={"agent_id": agent_id}, + ) + + try: + await memory_store.put_session(session) + except Exception as e: + logger.error(f"Failed to save session: {e}", exc_info=True) + + # 构建输入消息列表(包含所有历史消息) + input_messages = [] + for msg in request.messages: + input_messages.append({ + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": self._extract_message_content(msg.content), + }) + + # 收集 Agent 响应 + agent_response_content = "" + + try: + # 流式处理 Agent 响应 + async for event in agent_handler(request): + # 收集文本内容 + if isinstance(event, str): + agent_response_content += event + elif isinstance(event, AgentEvent): + if event.event == EventType.TEXT and "delta" in event.data: + agent_response_content += event.data["delta"] + + # 透传事件 + yield event + + # 保存完整的对话轮次(输入 + 输出) + if agent_response_content: + try: + # 将助手响应添加到消息列表 + output_messages = input_messages + [{ + "role": "assistant", + "content": agent_response_content, + }] + + # 将完整的对话历史存储为一条消息 + # content 字段存储 JSON 格式的消息列表 + conversation_message = Message( + session_id=session_id, + message_id=f"msg_{uuid.uuid4().hex[:16]}", + content=json.dumps(output_messages, ensure_ascii=False), + ) + await memory_store.put_message(conversation_message) + + # 更新 Session 时间 + session.update_time = microseconds_timestamp() + await memory_store.update_session(session) + + logger.debug( + f"Saved conversation: {len(output_messages)} messages, " + f"response length: {len(agent_response_content)} chars" + ) + except Exception as e: + logger.error( + f"Failed to save conversation: {e}", exc_info=True + ) + + except Exception as e: + logger.error(f"Error in agent handler: {e}", exc_info=True) + raise + + @staticmethod + def _extract_message_content(content: Any) -> str: + """提取消息内容为字符串 + + Args: + content: 消息内容(可能是字符串或多模态内容列表) + + Returns: + str: 提取的文本内容 + """ + if isinstance(content, str): + return content + elif isinstance(content, list): + # 多模态内容,提取文本部分 + text_parts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + return " ".join(text_parts) + else: + return str(content) if content else "" + + async def close(self): + """关闭 OTS 客户端连接""" + if self._ots_client: + await self._ots_client.close() + logger.info("Memory Store connection closed") diff --git a/examples/server_with_memory.py b/examples/server_with_memory.py new file mode 100644 index 0000000..0d6f025 --- /dev/null +++ b/examples/server_with_memory.py @@ -0,0 +1,87 @@ +"""AgentRun Server with Memory Integration Example / 带记忆集成的 AgentRun Server 示例 + +演示如何使用 MemoryIntegration 自动存储对话历史到 TableStore。 + +运行前需要设置环境变量: + export AGENTRUN_ACCESS_KEY_ID="your-access-key-id" + export AGENTRUN_ACCESS_KEY_SECRET="your-access-key-secret" + export AGENTRUN_REGION="cn-hangzhou" + export MODEL_SERVICE="your-model-service" + export MODEL_NAME="qwen3-max" + export SANDBOX_NAME="your-sandbox" + export MEMORY_COLLECTION_NAME="your-memory-collection" + +运行示例: + uv run python examples/server_with_memory.py +""" + +import os + +from langchain.agents import create_agent + +from agentrun import AgentRequest +from agentrun.integration.langchain import ( + AgentRunConverter, + model, + sandbox_toolset, +) +from agentrun.memory_collection import MemoryConversation +from agentrun.sandbox import TemplateType +from agentrun.server import AgentRunServer + +# 配置参数 +MODEL_SERVICE = os.getenv("MODEL_SERVICE", "qwen3-max") +MODEL_NAME = os.getenv("MODEL_NAME", "qwen3-max") +SANDBOX_NAME = os.getenv("SANDBOX_NAME", "") +MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "mem-ots0129") + +# 创建 Agent +agent = create_agent( + # 使用 AgentRun 注册的模型 + model=model(MODEL_SERVICE, model=MODEL_NAME), + system_prompt=""" +你是一个诗人,根据用户输入内容写一个20字以内的诗文 +""", + # 使用 AgentRun 的 Sandbox 工具 + # tools=[*sandbox_toolset(SANDBOX_NAME, template_type=TemplateType.BROWSER)], +) + +# 初始化 Memory Integration +memory = MemoryConversation(memory_collection_name=MEMORY_COLLECTION_NAME) + + +async def invoke_agent(req: AgentRequest): + """Agent 调用函数,集成了记忆存储功能""" + try: + converter = AgentRunConverter() + + # 定义原始的 agent 处理函数 + async def agent_handler(request: AgentRequest): + result = agent.astream_events( + { + "messages": [ + {"role": msg.role, "content": msg.content} + for msg in request.messages + ] + }, + config={"recursion_limit": 1000}, + ) + async for event in result: + for agentrun_event in converter.convert(event): + yield agentrun_event + + # 使用 MemoryIntegration 包装,自动存储对话历史 + async for event in memory.wrap_invoke_agent(req, agent_handler): + yield event + + except Exception as e: + print(f"Error in invoke_agent: {e}") + raise Exception("Internal Error") + + +# 创建并启动 Server +if __name__ == "__main__": + server = AgentRunServer(invoke_agent=invoke_agent) + print(f"Server starting with memory collection: {MEMORY_COLLECTION_NAME}") + print("Memory will be automatically saved to TableStore") + server.start(port=9000) diff --git a/tests/unittests/memory_collection/test_memory_conversation.py b/tests/unittests/memory_collection/test_memory_conversation.py new file mode 100644 index 0000000..9624898 --- /dev/null +++ b/tests/unittests/memory_collection/test_memory_conversation.py @@ -0,0 +1,270 @@ +"""Tests for AgentRun Memory Conversation / AgentRun 记忆对话测试""" + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from agentrun.memory_collection import MemoryConversation +from agentrun.server.model import AgentRequest, Message, MessageRole + + +@pytest.fixture +def mock_memory_collection(): + """Mock MemoryCollection""" + with patch("agentrun.memory_collection.MemoryCollection") as mock: + # Mock get_by_name_async + mock_collection = MagicMock() + mock_collection.vector_store_config = MagicMock() + mock_collection.vector_store_config.provider = "aliyun_tablestore" + mock_collection.vector_store_config.config = MagicMock() + mock_collection.vector_store_config.config.endpoint = ( + "https://test.cn-hangzhou.ots.aliyuncs.com" + ) + mock_collection.vector_store_config.config.instance_name = ( + "test-instance" + ) + + mock.get_by_name_async = AsyncMock(return_value=mock_collection) + yield mock + + +@pytest.fixture +def mock_memory_store(): + """Mock AsyncMemoryStore""" + with patch( + "tablestore_for_agent_memory.memory.async_memory_store.AsyncMemoryStore" + ) as mock_store_class: + mock_store = AsyncMock() + mock_store.put_session = AsyncMock() + mock_store.put_message = AsyncMock() + mock_store.update_session = AsyncMock() + mock_store.init_table = AsyncMock() + mock_store.init_search_index = AsyncMock() + mock_store_class.return_value = mock_store + yield mock_store + + +@pytest.fixture +def mock_ots_client(): + """Mock AsyncOTSClient""" + with patch("tablestore.AsyncOTSClient") as mock: + mock_client = AsyncMock() + mock.return_value = mock_client + yield mock_client + + +@pytest.fixture +def mock_request(): + """Create a mock Starlette Request""" + mock_req = Mock() + mock_headers = Mock() + mock_headers.get = Mock(return_value="user123") + mock_query = Mock() + mock_query.get = Mock(return_value=None) + + mock_req.headers = mock_headers + mock_req.query_params = mock_query + mock_req.client = None + + return mock_req + + +class TestMemoryConversation: + """Test MemoryConversation class""" + + def test_default_user_id_extractor(self, mock_request): + """Test default user_id extraction""" + # Test with X-User-ID header + request = AgentRequest.model_construct( + messages=[], + raw_request=mock_request, + ) + + user_id = MemoryConversation._default_user_id_extractor(request) + assert user_id == "user123" + + def test_default_user_id_extractor_fallback(self): + """Test user_id extraction fallback to default""" + request = AgentRequest(messages=[]) + + user_id = MemoryConversation._default_user_id_extractor(request) + assert user_id == "default_user" + + def test_default_session_id_extractor(self): + """Test default session_id extraction""" + # Test with X-Session-ID header + mock_req = Mock() + mock_headers = Mock() + mock_headers.get = Mock(return_value="session456") + mock_query = Mock() + mock_query.get = Mock(return_value=None) + + mock_req.headers = mock_headers + mock_req.query_params = mock_query + + request = AgentRequest.model_construct( + messages=[], + raw_request=mock_req, + ) + + session_id = MemoryConversation._default_session_id_extractor(request) + assert session_id == "session456" + + def test_default_session_id_extractor_from_message(self): + """Test session_id extraction from message ID""" + request = AgentRequest( + messages=[ + Message(id="msg123", role=MessageRole.USER, content="Hello") + ] + ) + + session_id = MemoryConversation._default_session_id_extractor(request) + assert session_id == "session_msg123" + + def test_default_session_id_extractor_generate(self): + """Test session_id generation""" + request = AgentRequest(messages=[]) + + session_id = MemoryConversation._default_session_id_extractor(request) + assert session_id.startswith("session_") + + def test_extract_message_content_string(self): + """Test extracting string content""" + content = "Hello, world!" + result = MemoryConversation._extract_message_content(content) + assert result == "Hello, world!" + + def test_extract_message_content_multimodal(self): + """Test extracting multimodal content""" + content = [ + {"type": "text", "text": "Hello"}, + {"type": "image", "url": "https://example.com/image.jpg"}, + {"type": "text", "text": "World"}, + ] + result = MemoryConversation._extract_message_content(content) + assert result == "Hello World" + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_basic( + self, mock_memory_collection, mock_memory_store, mock_ots_client + ): + """Test basic wrap_invoke_agent functionality""" + # Create MemoryConversation + memory = MemoryConversation(memory_collection_name="test-memory") + + # Mock agent handler + async def mock_agent(request: AgentRequest): + yield "Hello" + yield ", " + yield "world!" + + # Create request + request = AgentRequest( + messages=[Message(role=MessageRole.USER, content="Hi there")] + ) + + # Wrap and collect results + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + # Verify results + assert results == ["Hello", ", ", "world!"] + + # Verify memory store calls + assert mock_memory_store.put_session.called + assert mock_memory_store.put_message.called + assert mock_memory_store.update_session.called + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_with_custom_extractors( + self, mock_memory_collection, mock_memory_store, mock_ots_client + ): + """Test wrap_invoke_agent with custom extractors""" + + # Custom extractors + def custom_user_extractor(req: AgentRequest) -> str: + return "custom_user" + + def custom_session_extractor(req: AgentRequest) -> str: + return "custom_session" + + # Create MemoryConversation with custom extractors + memory = MemoryConversation( + memory_collection_name="test-memory", + user_id_extractor=custom_user_extractor, + session_id_extractor=custom_session_extractor, + ) + + # Mock agent handler + async def mock_agent(request: AgentRequest): + yield "Response" + + # Create request + request = AgentRequest( + messages=[Message(role=MessageRole.USER, content="Test")] + ) + + # Wrap and collect results + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + # Verify results + assert results == ["Response"] + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_handles_errors( + self, mock_memory_collection, mock_memory_store, mock_ots_client + ): + """Test that memory errors don't break agent responses""" + # Make memory store raise error + mock_memory_store.put_session.side_effect = Exception("Storage error") + + # Create MemoryConversation + memory = MemoryConversation(memory_collection_name="test-memory") + + # Mock agent handler + async def mock_agent(request: AgentRequest): + yield "Still works!" + + # Create request + request = AgentRequest( + messages=[Message(role=MessageRole.USER, content="Test")] + ) + + # Wrap and collect results - should still work + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + # Verify agent still responds + assert results == ["Still works!"] + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_without_dependencies(self): + """Test graceful fallback when dependencies not installed""" + memory = MemoryConversation(memory_collection_name="test-memory") + + # Force _memory_store to None to simulate uninitialized state + memory._memory_store = None + + # Mock _get_memory_store to raise ImportError + async def mock_get_memory_store(): + raise ImportError("Module not found") + + memory._get_memory_store = mock_get_memory_store + + async def mock_agent(request: AgentRequest): + yield "Response" + + request = AgentRequest( + messages=[Message(role=MessageRole.USER, content="Test")] + ) + + # Should still work, just without storage + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + assert results == ["Response"] From c4bd09e1aac2334c145a32e8ea9fe54f4b3f1dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Mon, 9 Feb 2026 11:43:18 +0800 Subject: [PATCH 2/4] feat(memory_collection): add MemoryConversation class for dialogue history management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I3c52d07b82111f00bfb3fd8b04bee13cd5ce8892 Co-developed-by: Cursor Signed-off-by: 久氢 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c7447f7..63966da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ server = [ "fastapi>=0.104.0", "uvicorn>=0.24.0", "ag-ui-protocol>=0.1.10", + "tablestore-for-agent-memory>=1.1.2", ] langchain = [ From 5c78ab69eb2f86b2c0f7383a87efbc456f65223e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Mon, 9 Feb 2026 19:22:02 +0800 Subject: [PATCH 3/4] feat(memory_collection): add MemoryConversation class for dialogue history management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I099853b5465234a34db594c9dea1210743174932 Co-developed-by: Cursor Signed-off-by: 久氢 --- .../memory_collection/memory_conversation.py | 152 ++++++++++---- agentrun/server/server.py | 54 +++++ examples/server_with_memory.py | 56 ++--- .../test_memory_conversation.py | 192 +++++++++++++++++- 4 files changed, 374 insertions(+), 80 deletions(-) diff --git a/agentrun/memory_collection/memory_conversation.py b/agentrun/memory_collection/memory_conversation.py index 0903547..6035288 100644 --- a/agentrun/memory_collection/memory_conversation.py +++ b/agentrun/memory_collection/memory_conversation.py @@ -2,25 +2,19 @@ 提供与 TableStore Memory 的集成能力,自动存储用户和 Agent 的对话历史。 -Example (基本使用): - >>> from agentrun.server import AgentRunServer, AgentRequest - >>> from agentrun.memory_collection import MemoryConversation - >>> - >>> # 初始化 Memory Conversation - >>> memory = MemoryConversation(memory_collection_name="my-memory") - >>> - >>> # 包装 invoke_agent 函数 - >>> async def invoke_agent(req: AgentRequest): - ... async for event in memory.wrap_invoke_agent(req, my_agent_handler): - ... yield event - >>> - >>> server = AgentRunServer(invoke_agent=invoke_agent) - >>> server.start() """ import json import os -from typing import Any, AsyncIterator, Callable, Dict, Optional, TYPE_CHECKING +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + List, + Optional, + TYPE_CHECKING, +) import uuid import tablestore @@ -88,13 +82,17 @@ def _default_user_id_extractor(req: Any) -> str: """默认的 user_id 提取器 优先级: - 1. X-User-ID 请求头 + 1. X-User-ID 请求头(支持多种格式) 2. user_id 查询参数 3. 默认值 "default_user" """ if req.raw_request: # 从请求头获取 - user_id = req.raw_request.headers.get("X-User-ID") + user_id = ( + req.raw_request.headers.get("X-AgentRun-User-ID") + or req.raw_request.headers.get("x-agentrun-user-id") + or req.raw_request.headers.get("X-Agentrun-User-Id") + ) if user_id: return user_id @@ -110,26 +108,20 @@ def _default_session_id_extractor(req: Any) -> str: """默认的 session_id 提取器 优先级: - 1. X-Session-ID 请求头 - 2. sessionId 查询参数 - 3. 从最后一条消息的 id 生成 - 4. 生成新的 UUID + 1. X-Session-ID 请求头(支持多种格式) + 2. 生成新的 UUID """ if req.raw_request: - # 从请求头获取 - session_id = req.raw_request.headers.get("X-Conversation-ID") - if session_id: - return session_id - - # 从查询参数获取 - session_id = req.raw_request.query_params.get("sessionId") + # 从请求头获取(兼容多种格式) + # 支持:X-AgentRun-Session-ID, x-agentrun-session-id, X-Agentrun-Session-Id + session_id = ( + req.raw_request.headers.get("X-AgentRun-Session-ID") + or req.raw_request.headers.get("x-agentrun-session-id") + or req.raw_request.headers.get("X-Agentrun-Session-Id") + ) if session_id: return session_id - # 从消息 ID 生成(如果有) - if req.messages and req.messages[-1].id: - return f"session_{req.messages[-1].id}" - # 生成新的 session_id return f"session_{uuid.uuid4().hex[:16]}" @@ -138,13 +130,18 @@ def _default_agent_id_extractor(req: Any) -> str: """默认的 agent_id 提取器 优先级: - 1. X-Agent-ID 请求头 + 1. X-Agent-ID 请求头(支持多种格式) 2. 从 URL 路径中提取 /agent-runtimes/{agent_id}/... 格式 3. 默认值 "default_agent" """ if req.raw_request: - # 从请求头获取 - agent_id = req.raw_request.headers.get("X-Agent-ID") + # 从请求头获取(兼容多种格式) + # 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id, X-Agent-ID, x-agent-id + agent_id = ( + req.raw_request.headers.get("X-AgentRun-Agent-ID") + or req.raw_request.headers.get("x-agentrun-agent-id") + or req.raw_request.headers.get("X-Agentrun-Agent-Id") + ) if agent_id: return agent_id @@ -407,8 +404,12 @@ async def wrap_invoke_agent( "content": self._extract_message_content(msg.content), }) - # 收集 Agent 响应 + # 收集 Agent 响应(包括文本和工具调用) agent_response_content = "" + tool_calls: Dict[str, Dict[str, Any]] = ( + {} + ) # tool_call_id -> tool_call_info + tool_results: List[Dict[str, Any]] = [] # 工具执行结果列表 try: # 流式处理 Agent 响应 @@ -420,17 +421,80 @@ async def wrap_invoke_agent( if event.event == EventType.TEXT and "delta" in event.data: agent_response_content += event.data["delta"] + # 收集工具调用信息 + elif event.event == EventType.TOOL_CALL: + # 完整的工具调用 + tool_id = event.data.get("id", "") + if tool_id: + tool_calls[tool_id] = { + "id": tool_id, + "type": "function", + "function": { + "name": event.data.get("name", ""), + "arguments": event.data.get("args", ""), + }, + } + + elif event.event == EventType.TOOL_CALL_CHUNK: + # 工具调用片段(流式场景) + tool_id = event.data.get("id", "") + if tool_id: + if tool_id not in tool_calls: + tool_calls[tool_id] = { + "id": tool_id, + "type": "function", + "function": { + "name": event.data.get("name", ""), + "arguments": "", + }, + } + # 累积参数片段 + if "args_delta" in event.data: + tool_calls[tool_id]["function"][ + "arguments" + ] += event.data["args_delta"] + + # 收集工具执行结果 + elif event.event == EventType.TOOL_RESULT: + tool_id = event.data.get("id", "") + if tool_id: + tool_results.append({ + "role": "tool", + "tool_call_id": tool_id, + "content": str(event.data.get("result", "")), + }) + # 透传事件 yield event # 保存完整的对话轮次(输入 + 输出) - if agent_response_content: + # 只有当有文本内容或工具调用时才保存 + if agent_response_content or tool_calls or tool_results: try: - # 将助手响应添加到消息列表 - output_messages = input_messages + [{ + # 构建助手响应消息 + assistant_message: Dict[str, Any] = { "role": "assistant", - "content": agent_response_content, - }] + } + + # 添加文本内容(如果有) + if agent_response_content: + assistant_message["content"] = agent_response_content + else: + # OpenAI 格式要求:如果有 tool_calls,content 可以为 null + assistant_message["content"] = None + + # 添加工具调用(如果有) + if tool_calls: + assistant_message["tool_calls"] = list( + tool_calls.values() + ) + + # 构建完整的消息列表 + output_messages = input_messages + [assistant_message] + + # 添加工具执行结果(如果有) + if tool_results: + output_messages.extend(tool_results) # 将完整的对话历史存储为一条消息 # content 字段存储 JSON 格式的消息列表 @@ -446,8 +510,10 @@ async def wrap_invoke_agent( await memory_store.update_session(session) logger.debug( - f"Saved conversation: {len(output_messages)} messages, " - f"response length: {len(agent_response_content)} chars" + f"Saved conversation: {len(output_messages)} messages," + f" text length: {len(agent_response_content)} chars," + f" tool_calls: {len(tool_calls)}, tool_results:" + f" {len(tool_results)}" ) except Exception as e: logger.error( diff --git a/agentrun/server/server.py b/agentrun/server/server.py index cfacef6..cb5eb67 100644 --- a/agentrun/server/server.py +++ b/agentrun/server/server.py @@ -83,6 +83,14 @@ class AgentRunServer: ... invoke_agent=invoke_agent, ... config=ServerConfig(cors_origins=["http://localhost:3000"]) ... ) + + Example (启用会话历史记录): + >>> server = AgentRunServer( + ... invoke_agent=invoke_agent, + ... memory_collection_name="my-memory-collection" + ... ) + >>> server.start(port=8000) + # 会话历史将自动保存到 TableStore """ def __init__( @@ -90,6 +98,7 @@ def __init__( invoke_agent: InvokeAgentHandler, protocols: Optional[List[ProtocolHandler]] = None, config: Optional[ServerConfig] = None, + memory_collection_name: Optional[str] = None, ): """初始化 AgentRun Server @@ -107,8 +116,20 @@ def __init__( - cors_origins: CORS 允许的源列表 - openai: OpenAI 协议配置 - agui: AG-UI 协议配置 + + memory_collection_name: MemoryCollection 名称(可选) + - 如果提供,将自动启用会话历史记录功能 + - 会话历史将保存到指定的 MemoryCollection 中 """ self.app = FastAPI(title="AgentRun Server") + + # 如果启用了 memory,包装 invoke_agent + if memory_collection_name: + invoke_agent = self._wrap_with_memory( + invoke_agent, + memory_collection_name, + ) + self.agent_invoker = AgentInvoker(invoke_agent) # 配置 CORS @@ -124,6 +145,39 @@ def __init__( # 挂载所有协议的 Router self._mount_protocols(protocols) + def _wrap_with_memory( + self, + invoke_agent: InvokeAgentHandler, + memory_collection_name: str, + ) -> InvokeAgentHandler: + """使用 MemoryConversation 包装 invoke_agent + + Args: + invoke_agent: 原始的 invoke_agent 函数 + memory_collection_name: MemoryCollection 名称 + + Returns: + 包装后的 invoke_agent 函数 + """ + from agentrun.memory_collection import MemoryConversation + + # 创建 MemoryConversation 实例 + memory = MemoryConversation( + memory_collection_name=memory_collection_name, + ) + + logger.info( + "Memory integration enabled for collection:" + f" {memory_collection_name}" + ) + + # 包装 invoke_agent + async def wrapped_invoke_agent(request: Any): + async for event in memory.wrap_invoke_agent(request, invoke_agent): + yield event + + return wrapped_invoke_agent + def _setup_cors(self, cors_origins: Optional[Sequence[str]] = None): """配置 CORS 中间件 diff --git a/examples/server_with_memory.py b/examples/server_with_memory.py index 0d6f025..af02695 100644 --- a/examples/server_with_memory.py +++ b/examples/server_with_memory.py @@ -32,7 +32,7 @@ # 配置参数 MODEL_SERVICE = os.getenv("MODEL_SERVICE", "qwen3-max") MODEL_NAME = os.getenv("MODEL_NAME", "qwen3-max") -SANDBOX_NAME = os.getenv("SANDBOX_NAME", "") +SANDBOX_NAME = os.getenv("SANDBOX_NAME", "sandbox-browser-BmUyyD") MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "mem-ots0129") # 创建 Agent @@ -40,10 +40,20 @@ # 使用 AgentRun 注册的模型 model=model(MODEL_SERVICE, model=MODEL_NAME), system_prompt=""" -你是一个诗人,根据用户输入内容写一个20字以内的诗文 +你是 AgentRun 的 AI 助手,可以通过网络搜索帮助用户解决问题 + + +你的工作流程如下 +- 当用户向你提问概念性问题时,不要直接回答,而是先进行网络搜索 +- 使用 Browser 工具打开百度搜索。如果要搜索 AgentRun,对应的搜索链接为: `https://www.baidu.com/s?ie=utf-8&wd=agentrun`。为了节省 token 使用,不要使用 `snapshot` 获取完整页面内容,而是通过 `evaluate` 获取你需要的部分 +- 获取百度搜索的结果,根据相关性分别打开子页面获取内容 + - 如果子页面的相关度较低,则可以直接忽略 + - 如果子页面的相关度较高,则将其记录为可参考的资料,记录页面标题和实时的 url +- 当你获得至少 3 条网络信息后,可以结束搜索,并根据搜索到的结果回答用户的问题。 +- 如果某一部分回答引用了网络的信息,需要进行标注,并在回答的最后给出跳转链接 """, # 使用 AgentRun 的 Sandbox 工具 - # tools=[*sandbox_toolset(SANDBOX_NAME, template_type=TemplateType.BROWSER)], + tools=[*sandbox_toolset(SANDBOX_NAME, template_type=TemplateType.BROWSER)], ) # 初始化 Memory Integration @@ -51,37 +61,31 @@ async def invoke_agent(req: AgentRequest): - """Agent 调用函数,集成了记忆存储功能""" + """Agent 调用函数""" try: converter = AgentRunConverter() - - # 定义原始的 agent 处理函数 - async def agent_handler(request: AgentRequest): - result = agent.astream_events( - { - "messages": [ - {"role": msg.role, "content": msg.content} - for msg in request.messages - ] - }, - config={"recursion_limit": 1000}, - ) - async for event in result: - for agentrun_event in converter.convert(event): - yield agentrun_event - - # 使用 MemoryIntegration 包装,自动存储对话历史 - async for event in memory.wrap_invoke_agent(req, agent_handler): - yield event - + result = agent.astream_events( + { + "messages": [ + {"role": msg.role, "content": msg.content} + for msg in req.messages + ] + }, + config={"recursion_limit": 1000}, + ) + async for event in result: + for agentrun_event in converter.convert(event): + yield agentrun_event except Exception as e: - print(f"Error in invoke_agent: {e}") + print(e) raise Exception("Internal Error") # 创建并启动 Server if __name__ == "__main__": - server = AgentRunServer(invoke_agent=invoke_agent) + server = AgentRunServer( + invoke_agent=invoke_agent, memory_collection_name=MEMORY_COLLECTION_NAME + ) print(f"Server starting with memory collection: {MEMORY_COLLECTION_NAME}") print("Memory will be automatically saved to TableStore") server.start(port=9000) diff --git a/tests/unittests/memory_collection/test_memory_conversation.py b/tests/unittests/memory_collection/test_memory_conversation.py index 9624898..812564c 100644 --- a/tests/unittests/memory_collection/test_memory_conversation.py +++ b/tests/unittests/memory_collection/test_memory_conversation.py @@ -110,17 +110,6 @@ def test_default_session_id_extractor(self): session_id = MemoryConversation._default_session_id_extractor(request) assert session_id == "session456" - def test_default_session_id_extractor_from_message(self): - """Test session_id extraction from message ID""" - request = AgentRequest( - messages=[ - Message(id="msg123", role=MessageRole.USER, content="Hello") - ] - ) - - session_id = MemoryConversation._default_session_id_extractor(request) - assert session_id == "session_msg123" - def test_default_session_id_extractor_generate(self): """Test session_id generation""" request = AgentRequest(messages=[]) @@ -268,3 +257,184 @@ async def mock_agent(request: AgentRequest): results.append(event) assert results == ["Response"] + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_with_tool_calls( + self, mock_memory_collection, mock_memory_store, mock_ots_client + ): + """Test that tool calls and results are saved correctly""" + from agentrun.server.model import AgentEvent, EventType + + # Create MemoryConversation + memory = MemoryConversation(memory_collection_name="test-memory") + + # Mock agent handler with tool calls + async def mock_agent(request: AgentRequest): + # First yield some text + yield "Let me search for that..." + + # Then yield a tool call + yield AgentEvent( + event=EventType.TOOL_CALL, + data={ + "id": "call_123", + "name": "search_tool", + "args": '{"query": "weather"}', + }, + ) + + # Then yield tool result + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "call_123", + "result": "Sunny, 25°C", + }, + ) + + # Finally yield more text + yield "Based on the search, it's sunny today." + + # Create request with raw_request mock + request = AgentRequest( + messages=[ + Message(role=MessageRole.USER, content="What's the weather?") + ] + ) + request.raw_request = MagicMock() + request.raw_request.headers = {"X-User-ID": "user123"} + + # Wrap and collect results + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + # Verify all events were passed through + assert len(results) == 4 + assert results[0] == "Let me search for that..." + assert results[3] == "Based on the search, it's sunny today." + + # Verify message was saved with tool calls + assert mock_memory_store.put_message.called + saved_message = mock_memory_store.put_message.call_args[0][0] + + # Parse the saved content + import json + + saved_content = json.loads(saved_message.content) + + # Should have: user message + assistant message + tool result + assert len(saved_content) >= 2 + + # Check assistant message has both content and tool_calls + assistant_msg = saved_content[1] + assert assistant_msg["role"] == "assistant" + assert "Let me search for that..." in assistant_msg["content"] + assert ( + "Based on the search, it's sunny today." in assistant_msg["content"] + ) + assert "tool_calls" in assistant_msg + assert len(assistant_msg["tool_calls"]) == 1 + + # Check tool call structure + tool_call = assistant_msg["tool_calls"][0] + assert tool_call["id"] == "call_123" + assert tool_call["type"] == "function" + assert tool_call["function"]["name"] == "search_tool" + assert tool_call["function"]["arguments"] == '{"query": "weather"}' + + # Check tool result + assert len(saved_content) == 3 + tool_result_msg = saved_content[2] + assert tool_result_msg["role"] == "tool" + assert tool_result_msg["tool_call_id"] == "call_123" + assert tool_result_msg["content"] == "Sunny, 25°C" + + @pytest.mark.asyncio + async def test_wrap_invoke_agent_with_tool_call_chunks( + self, mock_memory_collection, mock_memory_store, mock_ots_client + ): + """Test that streaming tool call chunks are accumulated correctly""" + from agentrun.server.model import AgentEvent, EventType + + # Create MemoryConversation + memory = MemoryConversation(memory_collection_name="test-memory") + + # Mock agent handler with streaming tool calls + async def mock_agent(request: AgentRequest): + # Yield tool call chunks (streaming scenario) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_456", + "name": "calculator", + "args_delta": '{"a":', + }, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_456", + "args_delta": ' 10, "b"', + }, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_456", + "args_delta": ": 20}", + }, + ) + + # Yield tool result + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "call_456", + "result": "30", + }, + ) + + # Create request with raw_request mock + request = AgentRequest( + messages=[ + Message(role=MessageRole.USER, content="Calculate 10 + 20") + ] + ) + request.raw_request = MagicMock() + request.raw_request.headers = {"X-User-ID": "user123"} + + # Wrap and collect results + results = [] + async for event in memory.wrap_invoke_agent(request, mock_agent): + results.append(event) + + # Verify all events were passed through + assert len(results) == 4 + + # Verify message was saved with accumulated tool call + assert mock_memory_store.put_message.called + saved_message = mock_memory_store.put_message.call_args[0][0] + + # Parse the saved content + import json + + saved_content = json.loads(saved_message.content) + + # Check assistant message has tool_calls with accumulated arguments + assistant_msg = saved_content[1] + assert assistant_msg["role"] == "assistant" + assert "tool_calls" in assistant_msg + assert len(assistant_msg["tool_calls"]) == 1 + + # Check accumulated arguments + tool_call = assistant_msg["tool_calls"][0] + assert tool_call["id"] == "call_456" + assert tool_call["function"]["name"] == "calculator" + assert tool_call["function"]["arguments"] == '{"a": 10, "b": 20}' + + # Check tool result + tool_result_msg = saved_content[2] + assert tool_result_msg["role"] == "tool" + assert tool_result_msg["tool_call_id"] == "call_456" + assert tool_result_msg["content"] == "30" From 0d879a9325aa3f8e7183506c02e5dc5ffd4e3778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Mon, 9 Feb 2026 21:59:24 +0800 Subject: [PATCH 4/4] feat(memory_collection): add MemoryConversation class for dialogue history management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I020c9d634d3db70e1b1fedc1f5c25e4407f63b4e Co-developed-by: Cursor Signed-off-by: 久氢 --- .../memory_collection/memory_conversation.py | 65 +++++++++---------- .../test_memory_conversation.py | 29 ++++++++- 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/agentrun/memory_collection/memory_conversation.py b/agentrun/memory_collection/memory_conversation.py index 6035288..64ab94f 100644 --- a/agentrun/memory_collection/memory_conversation.py +++ b/agentrun/memory_collection/memory_conversation.py @@ -108,20 +108,26 @@ def _default_session_id_extractor(req: Any) -> str: """默认的 session_id 提取器 优先级: - 1. X-Session-ID 请求头(支持多种格式) - 2. 生成新的 UUID + 1. X-Conversation-ID 请求头(支持多种格式) + 2. sessionId 查询参数 + 3. 生成新的 UUID """ if req.raw_request: # 从请求头获取(兼容多种格式) - # 支持:X-AgentRun-Session-ID, x-agentrun-session-id, X-Agentrun-Session-Id + # 支持:X-AgentRun-Conversation-ID, x-agentrun-conversation-id, X-Agentrun-Conversation-Id session_id = ( - req.raw_request.headers.get("X-AgentRun-Session-ID") - or req.raw_request.headers.get("x-agentrun-session-id") - or req.raw_request.headers.get("X-Agentrun-Session-Id") + req.raw_request.headers.get("X-AgentRun-Conversation-ID") + or req.raw_request.headers.get("x-agentrun-conversation-id") + or req.raw_request.headers.get("X-Agentrun-Conversation-Id") ) if session_id: return session_id + # 从查询参数获取 + session_id = req.raw_request.query_params.get("sessionId") + if session_id: + return session_id + # 生成新的 session_id return f"session_{uuid.uuid4().hex[:16]}" @@ -131,12 +137,11 @@ def _default_agent_id_extractor(req: Any) -> str: 优先级: 1. X-Agent-ID 请求头(支持多种格式) - 2. 从 URL 路径中提取 /agent-runtimes/{agent_id}/... 格式 - 3. 默认值 "default_agent" + 2. 默认值 "default_agent" """ if req.raw_request: # 从请求头获取(兼容多种格式) - # 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id, X-Agent-ID, x-agent-id + # 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id agent_id = ( req.raw_request.headers.get("X-AgentRun-Agent-ID") or req.raw_request.headers.get("x-agentrun-agent-id") @@ -145,25 +150,6 @@ def _default_agent_id_extractor(req: Any) -> str: if agent_id: return agent_id - # 从 URL 路径中提取 - # 例如:/agent-runtimes/agent-quick-xFGD/invoke -> agent-quick-xFGD - try: - path = ( - req.raw_request.url.path - if hasattr(req.raw_request.url, "path") - else str(req.raw_request.url) - ) - if "/agent-runtimes/" in path: - # 提取 /agent-runtimes/ 后面的部分 - parts = path.split("/agent-runtimes/", 1) - if len(parts) > 1: - # 获取下一个路径段 - agent_part = parts[1].split("/")[0] - if agent_part: - return agent_part - except Exception: - pass - return "default_agent" async def _get_memory_store(self): @@ -193,12 +179,19 @@ async def _get_memory_store(self): ots_config = await self._get_ots_config_from_memory_collection() # 创建 AsyncOTSClient - self._ots_client = tablestore.AsyncOTSClient( - end_point=ots_config["endpoint"], - access_key_id=ots_config["access_key_id"], - access_key_secret=ots_config["access_key_secret"], - instance_name=ots_config["instance_name"], - ) + # 支持使用 STS 临时凭证访问 TableStore + client_kwargs = { + "end_point": ots_config["endpoint"], + "access_key_id": ots_config["access_key_id"], + "access_key_secret": ots_config["access_key_secret"], + "instance_name": ots_config["instance_name"], + } + + # 如果提供了 security_token,则添加到参数中(支持 STS 临时凭证) + if ots_config.get("security_token"): + client_kwargs["sts_token"] = ots_config["security_token"] + + self._ots_client = tablestore.AsyncOTSClient(**client_kwargs) # 配置会话表的二级索引元数据字段 # agent_id 字段用于标识会话所属的 Agent @@ -258,6 +251,7 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]: - endpoint: OTS endpoint - access_key_id: 访问密钥 ID - access_key_secret: 访问密钥 Secret + - security_token: STS 安全令牌(可选,用于临时凭证) - instance_name: OTS 实例名称 """ from agentrun.memory_collection import MemoryCollection @@ -309,6 +303,9 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]: "instance_name": vs_config.instance_name or "", "access_key_id": self.config.get_access_key_id(), "access_key_secret": self.config.get_access_key_secret(), + "security_token": ( + self.config.get_security_token() + ), # 支持 STS 临时凭证 } return ots_config diff --git a/tests/unittests/memory_collection/test_memory_conversation.py b/tests/unittests/memory_collection/test_memory_conversation.py index 812564c..20b2592 100644 --- a/tests/unittests/memory_collection/test_memory_conversation.py +++ b/tests/unittests/memory_collection/test_memory_conversation.py @@ -92,10 +92,14 @@ def test_default_user_id_extractor_fallback(self): def test_default_session_id_extractor(self): """Test default session_id extraction""" - # Test with X-Session-ID header + # Test with X-AgentRun-Conversation-ID header mock_req = Mock() mock_headers = Mock() - mock_headers.get = Mock(return_value="session456") + mock_headers.get = Mock( + side_effect=lambda k: { + "X-AgentRun-Conversation-ID": "session456" + }.get(k) + ) mock_query = Mock() mock_query.get = Mock(return_value=None) @@ -110,6 +114,27 @@ def test_default_session_id_extractor(self): session_id = MemoryConversation._default_session_id_extractor(request) assert session_id == "session456" + def test_default_session_id_extractor_from_query(self): + """Test session_id extraction from query parameter""" + mock_req = Mock() + mock_headers = Mock() + mock_headers.get = Mock(return_value=None) + mock_query = Mock() + mock_query.get = Mock( + side_effect=lambda k: {"sessionId": "query_session789"}.get(k) + ) + + mock_req.headers = mock_headers + mock_req.query_params = mock_query + + request = AgentRequest.model_construct( + messages=[], + raw_request=mock_req, + ) + + session_id = MemoryConversation._default_session_id_extractor(request) + assert session_id == "query_session789" + def test_default_session_id_extractor_generate(self): """Test session_id generation""" request = AgentRequest(messages=[])