Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 158 additions & 71 deletions veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@

import asyncio
import json
import time
import uuid
import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional

from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
from fastapi.responses import StreamingResponse
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.agents.run_config import StreamingMode
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.adk_web_server import RunAgentRequest
from google.adk.runners import Runner as GoogleRunner
from google.adk.runners import Runner as GoogleRunner, RunConfig
from google.adk.sessions import InMemorySessionService, Session
from google.adk.tools.mcp_tool.mcp_session_manager import (
StreamableHTTPConnectionParams,
Expand All @@ -34,6 +37,7 @@
from veadk import Runner
from veadk.utils.logger import get_logger


if TYPE_CHECKING:
from veadk import Agent

Expand All @@ -48,47 +52,121 @@ class ExtraRoute(BaseModel):
methods: list[str]


class WebsocketSessionManager:
def __init__(self):
# ws id -> ws instance
self.connections: dict[str, WebSocket] = {}
@dataclass
class ClientResource:
    """Per-client bundle of everything the server keeps for one websocket.

    Instances are plain records: the connection itself, the agent and
    services bound to it, the replies still being waited for, and the time
    the record was last touched.
    """

    # Live websocket connection of the client.
    websocket: WebSocket
    # Agent instance dedicated to this client.
    agent: "Agent"
    # In-memory services owned by this client only.
    session_service: InMemorySessionService
    artifact_service: InMemoryArtifactService
    # Request id -> future that is resolved when the matching reply arrives.
    pending_requests: dict[str, asyncio.Future] = field(default_factory=dict)
    # Wall-clock timestamp (``time.time()``) of the most recent activity.
    last_active_time: float = field(default_factory=time.time)

    def update_activity(self):
        """Stamp the record with the current time to mark it as active."""
        current_time = time.time()
        self.last_active_time = current_time

# ws id -> msg id -> ret
self.pendings: dict[str, dict[str, asyncio.Future]] = {}

async def call_mcp_http(self, ws_id: str, request: dict):
class ResourceManager:
    """Registry of per-client resources with idle-timeout cleanup.

    Every client is stored under its ``client_id`` as a ``ClientResource``.
    The manager also relays MCP requests to a client over its websocket and
    matches the replies back to the waiting callers.

    ``self._lock`` guards every read and write of ``self.resources``. It is a
    ``threading.Lock``, so it is only ever held for short, non-awaiting
    sections.
    """

    def __init__(
        self,
        timeout_seconds: int = 3600,
        cleanup_interval_seconds: float = 60,
    ):
        """Create an empty manager.

        Args:
            timeout_seconds: Idle time after which a client is removed by the
                cleanup loop.
            cleanup_interval_seconds: Pause between two cleanup passes.
        """
        self._lock: threading.Lock = threading.Lock()
        self.resources: dict[str, ClientResource] = {}
        self.timeout_seconds = timeout_seconds
        self.cleanup_interval_seconds = cleanup_interval_seconds
        self.cleanup_task: Optional[asyncio.Task] = None

    def register(
        self,
        client_id: str,
        websocket: WebSocket,
        agent: "Agent",
        session_service: InMemorySessionService,
        artifact_service: InMemoryArtifactService,
    ):
        """Store the resources of a newly connected client.

        An existing entry with the same ``client_id`` is overwritten.
        """
        with self._lock:
            self.resources[client_id] = ClientResource(
                websocket=websocket,
                agent=agent,
                session_service=session_service,
                artifact_service=artifact_service,
            )
        logger.info(f"client {client_id} registered")

    def get(self, client_id: str) -> Optional[ClientResource]:
        """Return the client's resources and mark the client as active.

        Returns:
            The ``ClientResource`` or ``None`` when the client is unknown.
        """
        with self._lock:
            logger.info(f"get {client_id}")
            resource = self.resources.get(client_id)
            if resource is not None:
                # Every lookup counts as activity for the idle timeout.
                resource.update_activity()
            return resource

    async def remove(self, client_id: str):
        """Drop a client, cancel its pending requests and close its websocket.

        Unknown ids are ignored. Errors while closing the websocket are
        logged, never raised.
        """
        with self._lock:
            resource = self.resources.pop(client_id, None)
        if resource is None:
            return

        # Cancel the waiters first: closing an already-closed websocket may
        # raise, and that must not leave callers of call_mcp_http hanging.
        for fut in resource.pending_requests.values():
            if not fut.done():
                fut.cancel()
        resource.pending_requests.clear()

        try:
            await resource.websocket.close()
        except Exception as e:
            logger.warning(f"client {client_id} resource websocket close error: {e}")

    async def start_cleanup_loop(self):
        """Run forever, periodically removing clients idle for too long."""
        logger.info("ResourceManager: active cleanup loop")
        while True:
            await asyncio.sleep(self.cleanup_interval_seconds)
            logger.debug("cleanup loop running...")
            now = time.time()

            # Work on a snapshot so the dict is never iterated while another
            # thread or task registers or removes a client.
            with self._lock:
                snapshot = list(self.resources.items())

            to_remove = []
            for client_id, resource in snapshot:
                logger.debug(
                    f"check {client_id}, last_active_time={resource.last_active_time}, timeout={self.timeout_seconds}"
                )
                if now - resource.last_active_time > self.timeout_seconds:
                    to_remove.append(client_id)

            for client_id in to_remove:
                logger.info(f"Removing inactive client {client_id}")
                await self.remove(client_id)

    def start(self):
        """Start the cleanup loop as a task; requires a running event loop.

        Calling it again while the task is still alive does nothing.
        """
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        self.cleanup_task = asyncio.create_task(self.start_cleanup_loop())

    def stop(self):
        """Cancel the cleanup task if one was started."""
        if self.cleanup_task:
            self.cleanup_task.cancel()
            self.cleanup_task = None

    async def call_mcp_http(self, client_id: str, request: dict):
        """Forward MCP request to client.

        The request is wrapped in an ``http_request`` message with a fresh
        id, sent over the client's websocket, and the coroutine then waits
        for the reply delivered through ``handle_ws_message``.

        Returns:
            The reply message, or ``b""`` when the client is unknown.

        Raises:
            asyncio.CancelledError: If the client is removed while waiting.
        """
        resource = self.get(client_id)
        if not resource:
            logger.error(f"Client {client_id} not found")
            return b""

        msg = {"id": str(uuid.uuid4()), "type": "http_request", "payload": request}
        fut = asyncio.get_running_loop().create_future()
        resource.pending_requests[msg["id"]] = fut

        try:
            await resource.websocket.send_text(json.dumps(msg))
            return await fut
        finally:
            # Never leave a stale entry behind, whether the send failed, the
            # caller was cancelled, or the reply already popped it.
            resource.pending_requests.pop(msg["id"], None)

    async def handle_ws_message(self, client_id: str, raw: str):
        """Resolve the pending request that a client message replies to.

        Messages that are not valid JSON objects of type ``http_response``
        are ignored, as are replies whose id matches no pending request.
        """
        resource = self.get(client_id)
        if not resource:
            return

        try:
            msg = json.loads(raw)
        except json.JSONDecodeError as e:
            # One malformed frame must not tear down the receive loop.
            logger.warning(f"client {client_id} sent invalid JSON: {e}")
            return

        if not isinstance(msg, dict) or msg.get("type") != "http_response":
            return

        req_id = msg.get("id")
        if not isinstance(req_id, str):
            logger.warning(f"client {client_id} sent a response without a valid id")
            return

        fut = resource.pending_requests.pop(req_id, None)
        if fut is None:
            logger.warning(
                f"client {client_id} sent a response for unknown request id {req_id}"
            )
            return

        # The waiter may have been cancelled in the meantime; set_result on a
        # finished future would raise InvalidStateError.
        if not fut.done():
            fut.set_result(msg)


class ServerWithReverseMCP:
Expand All @@ -102,27 +180,25 @@ def __init__(
extra_routes: list[ExtraRoute] | None = None,
):
self.agent = agent

self.host = host
self.port = port

self.extra_routes = extra_routes

self.app = FastAPI(
openapi_url=None,
docs_url=None,
redoc_url=None,
swagger_ui_oauth2_redirect_url=None,
)
self.app = FastAPI()

self.artifact_service = InMemoryArtifactService()
self.resource_manager = ResourceManager()

# build routes for self.app
self.build()

self.ws_session_mgr = WebsocketSessionManager()
self.ws_agent_mgr: dict[str, "Agent"] = {}
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}
@self.app.on_event("startup")
async def startup_event():
self.resource_manager.start()

@self.app.on_event("shutdown")
async def shutdown_event():
self.resource_manager.stop()

def build(self):
logger.info("Build routes for server with reverse mcp")
Expand All @@ -149,9 +225,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
session_id = payload.session_id
prompt = payload.prompt

agent = self.ws_agent_mgr[payload.websocket_id]
resource = self.resource_manager.get(payload.websocket_id)
if not resource:
raise HTTPException(
status_code=404, detail=f"Client {payload.websocket_id} not found"
)
agent = resource.agent

runner = Runner(app_name=payload.app_name, agent=agent)
runner = Runner(
app_name=payload.app_name,
agent=agent,
session_service=resource.session_service,
)
response = await runner.run(
messages=[prompt],
user_id=user_id,
Expand All @@ -160,6 +245,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:

return InvokeResponse(response=response)

@self.app.delete("/management/clients/{client_id}")
async def delete_client(client_id: str):
"""Manually remove a client resource."""
await self.resource_manager.remove(client_id)
return {"status": "success", "client_id": client_id}

# build websocket endpoint
@self.app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
Expand All @@ -179,15 +270,10 @@ async def ws_endpoint(ws: WebSocket):
filters = [t.strip() for t in filters_str.split(",") if t.strip()]

logger.info(f"Register websocket {client_id} to session manager.")
self.ws_session_mgr.connections[client_id] = ws

logger.info(f"Fork agent for websocket {client_id}")
agent = self.agent.clone()

logger.info(
f"clone agent \n model_name={agent.model_name}\n instruction={agent.instruction}\n"
)

# Mount MCPToolset when creating agent
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
Expand All @@ -201,10 +287,18 @@ async def ws_endpoint(ws: WebSocket):
tool_filter=filters,
)
)
self.ws_agent_mgr[client_id] = agent

logger.info(f"Create session service for websocket {client_id}")
self.ws_session_service_mgr[client_id] = InMemorySessionService()
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()

self.resource_manager.register(
client_id=client_id,
websocket=ws,
agent=agent,
session_service=session_service,
artifact_service=artifact_service,
)

await ws.accept()
logger.info(f"Websocket {client_id} connected")
Expand All @@ -213,7 +307,7 @@ async def ws_endpoint(ws: WebSocket):
while True:
raw = await ws.receive_text()
logger.debug(f"ws.receive_text() = {raw}")
await self.ws_session_mgr.handle_ws_message(client_id, raw)
await self.resource_manager.handle_ws_message(client_id, raw)
except Exception as e:
logger.warning(f"client {client_id} web socket connection closed: {e}")

Expand All @@ -227,12 +321,12 @@ class RunAgentRequestWithWsId(RunAgentRequest):

def _get_session_service(websocket_id: str) -> InMemorySessionService:
"""Get session service for the websocket client."""
if websocket_id not in self.ws_session_service_mgr:
resource = self.resource_manager.get(websocket_id)
if not resource:
raise HTTPException(
status_code=404,
detail=f"WebSocket client {websocket_id} not found",
status_code=404, detail=f"WebSocket client {websocket_id} not found"
)
return self.ws_session_service_mgr[websocket_id]
return resource.session_service

@self.app.post(
"/apps/{app_name}/users/{user_id}/sessions",
Expand Down Expand Up @@ -291,11 +385,18 @@ async def create_session_with_id(
return session

@self.app.post("/run_sse")
async def run_agent_sse(
req: RunAgentRequestWithWsId,
) -> StreamingResponse:
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
"""Run agent with SSE streaming."""
session_service = _get_session_service(req.websocket_id)
resource = self.resource_manager.get(req.websocket_id)
if not resource:
raise HTTPException(
status_code=404,
detail=f"WebSocket client {req.websocket_id} not found",
)

session_service = resource.session_service
agent = resource.agent
logger.debug(f"Using agent from websocket {req.websocket_id}")

# Get session
session = await session_service.get_session(
Expand All @@ -306,16 +407,6 @@ async def run_agent_sse(
if not session:
raise HTTPException(status_code=404, detail="Session not found")

# Get agent for this websocket
if req.websocket_id in self.ws_agent_mgr:
agent = self.ws_agent_mgr[req.websocket_id]
logger.debug(f"Using agent from websocket {req.websocket_id}")
else:
raise HTTPException(
status_code=404,
detail=f"WebSocket client {req.websocket_id} not found",
)

# Create runner
runner = GoogleRunner(
agent=agent,
Expand Down Expand Up @@ -354,10 +445,7 @@ async def event_generator():
content_event.actions.artifact_delta = {}
artifact_event = event.model_copy(deep=True)
artifact_event.content = None
events_to_stream = [
content_event,
artifact_event,
]
events_to_stream = [content_event, artifact_event]

for event_to_stream in events_to_stream:
sse_event = event_to_stream.model_dump_json(
Expand All @@ -367,7 +455,7 @@ async def event_generator():
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception(f"Error in event_generator: {e}")
yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
yield f"data: {json.dumps({'error': str(e)})}\n\n"

return StreamingResponse(
event_generator(),
Expand All @@ -391,8 +479,7 @@ async def mcp_proxy(path: str, request: Request):
if not client_id:
return Response("client id not found", status_code=400)

ws = self.ws_session_mgr.connections.get(client_id)
if not ws:
if not self.resource_manager.get(client_id):
return Response("websocket `client_id` not connected", status_code=503)

body = await request.body()
Expand All @@ -409,7 +496,7 @@ async def mcp_proxy(path: str, request: Request):

logger.debug(f"[Reverse mcp proxy] Request from agent: {payload}")

resp = await self.ws_session_mgr.call_mcp_http(client_id, payload)
resp = await self.resource_manager.call_mcp_http(client_id, payload)

logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")

Expand Down