From 64951392f20d1b01431be0a5026dc7515f257048 Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Sun, 22 Feb 2026 16:12:58 +0800 Subject: [PATCH] feat(client): add multi-tenant support for client --- .../reverse_mcp/server_with_reverse_mcp.py | 229 ++++++++++++------ 1 file changed, 158 insertions(+), 71 deletions(-) diff --git a/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py b/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py index da1480d9..9de83556 100644 --- a/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py +++ b/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py @@ -14,15 +14,18 @@ import asyncio import json +import time import uuid +import threading +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Optional from fastapi import FastAPI, HTTPException, Request, Response, WebSocket from fastapi.responses import StreamingResponse -from google.adk.agents.run_config import RunConfig, StreamingMode +from google.adk.agents.run_config import StreamingMode from google.adk.artifacts import InMemoryArtifactService from google.adk.cli.adk_web_server import RunAgentRequest -from google.adk.runners import Runner as GoogleRunner +from google.adk.runners import Runner as GoogleRunner, RunConfig from google.adk.sessions import InMemorySessionService, Session from google.adk.tools.mcp_tool.mcp_session_manager import ( StreamableHTTPConnectionParams, @@ -34,6 +37,7 @@ from veadk import Runner from veadk.utils.logger import get_logger + if TYPE_CHECKING: from veadk import Agent @@ -48,47 +52,121 @@ class ExtraRoute(BaseModel): methods: list[str] -class WebsocketSessionManager: - def __init__(self): - # ws id -> ws instance - self.connections: dict[str, WebSocket] = {} +@dataclass +class ClientResource: + websocket: WebSocket + agent: "Agent" + session_service: InMemorySessionService + artifact_service: InMemoryArtifactService + pending_requests: dict[str, asyncio.Future] = field(default_factory=dict) + last_active_time: float = field(default_factory=time.time) + + def update_activity(self): + self.last_active_time = time.time() - # ws id -> msg id -> ret - self.pendings: dict[str, dict[str, asyncio.Future]] = {} - async def call_mcp_http(self, ws_id: str, request: dict): +class ResourceManager: + def __init__(self, timeout_seconds: int = 3600): + self._lock: threading.Lock = threading.Lock() + self.resources: dict[str, ClientResource] = {} + self.timeout_seconds = timeout_seconds + self.cleanup_task: Optional[asyncio.Task] = None + + def register( + self, + client_id: str, + websocket: WebSocket, + agent: "Agent", + session_service: InMemorySessionService, + artifact_service: InMemoryArtifactService, + ): + with self._lock: + self.resources[client_id] = ClientResource( + websocket=websocket, + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + ) + logger.info(f"client {client_id} registered") + + def get(self, client_id: str) -> Optional[ClientResource]: + with self._lock: + logger.info(f"get {client_id}") + resource = self.resources.get(client_id) + if resource: + resource.update_activity() + return resource + + async def remove(self, client_id: str): + if client_id in self.resources: + resource = self.resources.pop(client_id) + try: + await resource.websocket.close() + for fut in resource.pending_requests.values(): + if not fut.done(): + fut.cancel() + except Exception as e: + logger.warning( + f"client {client_id} resource websocket close error: {e}" + ) + pass + + async def start_cleanup_loop(self): + logger.info("ResourceManager: active cleanup loop") + while True: + await asyncio.sleep(60) # Check every minute + logger.debug("cleanup loop running...") + now = time.time() + to_remove = [] + for client_id, resource in self.resources.items(): + logger.debug( + f"check {client_id}, last_active_time={resource.last_active_time}, timeout={self.timeout_seconds}" + ) + if now - resource.last_active_time > self.timeout_seconds: + to_remove.append(client_id) + + for client_id in to_remove: + logger.info(f"Removing inactive client {client_id}") + await self.remove(client_id) + + def start(self): + self.cleanup_task = asyncio.create_task(self.start_cleanup_loop()) + + def stop(self): + if self.cleanup_task: + self.cleanup_task.cancel() + + async def call_mcp_http(self, client_id: str, request: dict): """Forward MCP request to client.""" - try: - ws = self.connections[ws_id] - except KeyError: - logger.error(f"Websocket {ws_id} not found") + resource = self.get(client_id) + if not resource: + logger.error(f"Client {client_id} not found") return b"" - msg = {} - - msg["id"] = str(uuid.uuid4()) - msg["type"] = "http_request" - msg["payload"] = request + ws = resource.websocket + msg = {"id": str(uuid.uuid4()), "type": "http_request", "payload": request} fut = asyncio.get_event_loop().create_future() - if ws_id not in self.pendings: - self.pendings[ws_id] = {} - - self.pendings[ws_id][msg["id"]] = fut + resource.pending_requests[msg["id"]] = fut await ws.send_text(json.dumps(msg)) return await fut - async def handle_ws_message(self, ws_id: str, raw: str): + async def handle_ws_message(self, client_id: str, raw: str): + resource = self.get(client_id) + if not resource: + return + msg = json.loads(raw) if msg.get("type") != "http_response": return req_id = msg["id"] - fut = self.pendings[ws_id].pop(req_id, None) + fut = resource.pending_requests.pop(req_id, None) if fut: fut.set_result(msg) + # todo : 异常ID处理 class ServerWithReverseMCP: @@ -102,27 +180,25 @@ def __init__( extra_routes: list[ExtraRoute] | None = None, ): self.agent = agent - self.host = host self.port = port - self.extra_routes = extra_routes - self.app = FastAPI( - openapi_url=None, - docs_url=None, - redoc_url=None, - swagger_ui_oauth2_redirect_url=None, - ) + self.app = FastAPI() self.artifact_service = InMemoryArtifactService() + self.resource_manager = ResourceManager() # build routes for self.app self.build() - self.ws_session_mgr = WebsocketSessionManager() - self.ws_agent_mgr: dict[str, "Agent"] = {} - self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {} + @self.app.on_event("startup") + async def startup_event(): + self.resource_manager.start() + + @self.app.on_event("shutdown") + async def shutdown_event(): + self.resource_manager.stop() def build(self): logger.info("Build routes for server with reverse mcp") @@ -149,9 +225,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse: session_id = payload.session_id prompt = payload.prompt - agent = self.ws_agent_mgr[payload.websocket_id] + resource = self.resource_manager.get(payload.websocket_id) + if not resource: + raise HTTPException( + status_code=404, detail=f"Client {payload.websocket_id} not found" + ) + agent = resource.agent - runner = Runner(app_name=payload.app_name, agent=agent) + runner = Runner( + app_name=payload.app_name, + agent=agent, + session_service=resource.session_service, + ) response = await runner.run( messages=[prompt], user_id=user_id, @@ -160,6 +245,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse: return InvokeResponse(response=response) + @self.app.delete("/management/clients/{client_id}") + async def delete_client(client_id: str): + """Manually remove a client resource.""" + await self.resource_manager.remove(client_id) + return {"status": "success", "client_id": client_id} + # build websocket endpoint @self.app.websocket("/ws") async def ws_endpoint(ws: WebSocket): @@ -179,15 +270,10 @@ async def ws_endpoint(ws: WebSocket): filters = [t.strip() for t in filters_str.split(",") if t.strip()] logger.info(f"Register websocket {client_id} to session manager.") - self.ws_session_mgr.connections[client_id] = ws logger.info(f"Fork agent for websocket {client_id}") agent = self.agent.clone() - logger.info( - f"clone agent \n model_name={agent.model_name}\n instruction={agent.instruction}\n" - ) - # Mount MCPToolset when creating agent mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp" mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id} @@ -201,10 +287,18 @@ async def ws_endpoint(ws: WebSocket): tool_filter=filters, ) ) - self.ws_agent_mgr[client_id] = agent logger.info(f"Create session service for websocket {client_id}") - self.ws_session_service_mgr[client_id] = InMemorySessionService() + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + + self.resource_manager.register( + client_id=client_id, + websocket=ws, + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + ) await ws.accept() logger.info(f"Websocket {client_id} connected") @@ -213,7 +307,7 @@ async def ws_endpoint(ws: WebSocket): while True: raw = await ws.receive_text() logger.debug(f"ws.receive_text() = {raw}") - await self.ws_session_mgr.handle_ws_message(client_id, raw) + await self.resource_manager.handle_ws_message(client_id, raw) except Exception as e: logger.warning(f"client {client_id} web socket connection closed: {e}") @@ -227,12 +321,12 @@ class RunAgentRequestWithWsId(RunAgentRequest): def _get_session_service(websocket_id: str) -> InMemorySessionService: """Get session service for the websocket client.""" - if websocket_id not in self.ws_session_service_mgr: + resource = self.resource_manager.get(websocket_id) + if not resource: raise HTTPException( - status_code=404, - detail=f"WebSocket client {websocket_id} not found", + status_code=404, detail=f"WebSocket client {websocket_id} not found" ) - return self.ws_session_service_mgr[websocket_id] + return resource.session_service @self.app.post( "/apps/{app_name}/users/{user_id}/sessions", @@ -291,11 +385,18 @@ async def create_session_with_id( return session @self.app.post("/run_sse") - async def run_agent_sse( - req: RunAgentRequestWithWsId, - ) -> StreamingResponse: + async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse: """Run agent with SSE streaming.""" - session_service = _get_session_service(req.websocket_id) + resource = self.resource_manager.get(req.websocket_id) + if not resource: + raise HTTPException( + status_code=404, + detail=f"WebSocket client {req.websocket_id} not found", + ) + + session_service = resource.session_service + agent = resource.agent + logger.debug(f"Using agent from websocket {req.websocket_id}") # Get session session = await session_service.get_session( @@ -306,16 +407,6 @@ async def run_agent_sse( if not session: raise HTTPException(status_code=404, detail="Session not found") - # Get agent for this websocket - if req.websocket_id in self.ws_agent_mgr: - agent = self.ws_agent_mgr[req.websocket_id] - logger.debug(f"Using agent from websocket {req.websocket_id}") - else: - raise HTTPException( - status_code=404, - detail=f"WebSocket client {req.websocket_id} not found", - ) - # Create runner runner = GoogleRunner( agent=agent, @@ -354,10 +445,7 @@ async def event_generator(): content_event.actions.artifact_delta = {} artifact_event = event.model_copy(deep=True) artifact_event.content = None - events_to_stream = [ - content_event, - artifact_event, - ] + events_to_stream = [content_event, artifact_event] for event_to_stream in events_to_stream: sse_event = event_to_stream.model_dump_json( @@ -367,7 +455,7 @@ async def event_generator(): yield f"data: {sse_event}\n\n" except Exception as e: logger.exception(f"Error in event_generator: {e}") - yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n" + yield f"data: {json.dumps({'error': str(e)})}\n\n" return StreamingResponse( event_generator(), @@ -391,8 +479,7 @@ async def mcp_proxy(path: str, request: Request): if not client_id: return Response("client id not found", status_code=400) - ws = self.ws_session_mgr.connections.get(client_id) - if not ws: + if not self.resource_manager.get(client_id): return Response("websocket `client_id` not connected", status_code=503) body = await request.body() @@ -409,7 +496,7 @@ async def mcp_proxy(path: str, request: Request): logger.debug(f"[Reverse mcp proxy] Request from agent: {payload}") - resp = await self.ws_session_mgr.call_mcp_http(client_id, payload) + resp = await self.resource_manager.call_mcp_http(client_id, payload) logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")