Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 47 additions & 40 deletions python/copilot/generated/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,29 +1146,36 @@ def session_compaction_compact_result_to_dict(x: SessionCompactionCompactResult)
return to_class(SessionCompactionCompactResult, x)


def _timeout_kwargs(timeout: Optional[float]) -> dict:
"""Build keyword arguments for optional timeout forwarding."""
if timeout is not None:
return {"timeout": timeout}
return {}


class ModelsApi:
def __init__(self, client: "JsonRpcClient"):
self._client = client

async def list(self) -> ModelsListResult:
return ModelsListResult.from_dict(await self._client.request("models.list", {}))
async def list(self, *, timeout: Optional[float] = None) -> ModelsListResult:
return ModelsListResult.from_dict(await self._client.request("models.list", {}, **_timeout_kwargs(timeout)))


class ToolsApi:
def __init__(self, client: "JsonRpcClient"):
self._client = client

async def list(self, params: ToolsListParams) -> ToolsListResult:
async def list(self, params: ToolsListParams, *, timeout: Optional[float] = None) -> ToolsListResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict))
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict, **_timeout_kwargs(timeout)))


class AccountApi:
def __init__(self, client: "JsonRpcClient"):
self._client = client

async def get_quota(self) -> AccountGetQuotaResult:
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}))
async def get_quota(self, *, timeout: Optional[float] = None) -> AccountGetQuotaResult:
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}, **_timeout_kwargs(timeout)))


class ServerRpc:
Expand All @@ -1179,113 +1186,113 @@ def __init__(self, client: "JsonRpcClient"):
self.tools = ToolsApi(client)
self.account = AccountApi(client)

async def ping(self, params: PingParams) -> PingResult:
async def ping(self, params: PingParams, *, timeout: Optional[float] = None) -> PingResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
return PingResult.from_dict(await self._client.request("ping", params_dict))
return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout)))


class ModelApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def get_current(self) -> SessionModelGetCurrentResult:
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}))
async def get_current(self, *, timeout: Optional[float] = None) -> SessionModelGetCurrentResult:
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def switch_to(self, params: SessionModelSwitchToParams) -> SessionModelSwitchToResult:
async def switch_to(self, params: SessionModelSwitchToParams, *, timeout: Optional[float] = None) -> SessionModelSwitchToResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict))
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict, **_timeout_kwargs(timeout)))


class ModeApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def get(self) -> SessionModeGetResult:
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}))
async def get(self, *, timeout: Optional[float] = None) -> SessionModeGetResult:
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def set(self, params: SessionModeSetParams) -> SessionModeSetResult:
async def set(self, params: SessionModeSetParams, *, timeout: Optional[float] = None) -> SessionModeSetResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict))
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict, **_timeout_kwargs(timeout)))


class PlanApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def read(self) -> SessionPlanReadResult:
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}))
async def read(self, *, timeout: Optional[float] = None) -> SessionPlanReadResult:
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def update(self, params: SessionPlanUpdateParams) -> SessionPlanUpdateResult:
async def update(self, params: SessionPlanUpdateParams, *, timeout: Optional[float] = None) -> SessionPlanUpdateResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict))
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict, **_timeout_kwargs(timeout)))

async def delete(self) -> SessionPlanDeleteResult:
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}))
async def delete(self, *, timeout: Optional[float] = None) -> SessionPlanDeleteResult:
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))


class WorkspaceApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def list_files(self) -> SessionWorkspaceListFilesResult:
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}))
async def list_files(self, *, timeout: Optional[float] = None) -> SessionWorkspaceListFilesResult:
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def read_file(self, params: SessionWorkspaceReadFileParams) -> SessionWorkspaceReadFileResult:
async def read_file(self, params: SessionWorkspaceReadFileParams, *, timeout: Optional[float] = None) -> SessionWorkspaceReadFileResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict))
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict, **_timeout_kwargs(timeout)))

async def create_file(self, params: SessionWorkspaceCreateFileParams) -> SessionWorkspaceCreateFileResult:
async def create_file(self, params: SessionWorkspaceCreateFileParams, *, timeout: Optional[float] = None) -> SessionWorkspaceCreateFileResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict))
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict, **_timeout_kwargs(timeout)))


class FleetApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def start(self, params: SessionFleetStartParams) -> SessionFleetStartResult:
async def start(self, params: SessionFleetStartParams, *, timeout: Optional[float] = None) -> SessionFleetStartResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict))
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict, **_timeout_kwargs(timeout)))


class AgentApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def list(self) -> SessionAgentListResult:
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}))
async def list(self, *, timeout: Optional[float] = None) -> SessionAgentListResult:
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def get_current(self) -> SessionAgentGetCurrentResult:
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}))
async def get_current(self, *, timeout: Optional[float] = None) -> SessionAgentGetCurrentResult:
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))

async def select(self, params: SessionAgentSelectParams) -> SessionAgentSelectResult:
async def select(self, params: SessionAgentSelectParams, *, timeout: Optional[float] = None) -> SessionAgentSelectResult:
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
params_dict["sessionId"] = self._session_id
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict))
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict, **_timeout_kwargs(timeout)))

async def deselect(self) -> SessionAgentDeselectResult:
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}))
async def deselect(self, *, timeout: Optional[float] = None) -> SessionAgentDeselectResult:
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))


class CompactionApi:
def __init__(self, client: "JsonRpcClient", session_id: str):
self._client = client
self._session_id = session_id

async def compact(self) -> SessionCompactionCompactResult:
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}))
async def compact(self, *, timeout: Optional[float] = None) -> SessionCompactionCompactResult:
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))


class SessionRpc:
Expand Down
133 changes: 133 additions & 0 deletions python/test_rpc_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Tests for timeout parameter on generated RPC methods."""
from unittest.mock import AsyncMock

import pytest

from copilot.generated.rpc import (
FleetApi,
Mode,
ModeApi,
ModelsApi,
PlanApi,
SessionFleetStartParams,
SessionModeSetParams,
ToolsApi,
ToolsListParams,
)


class TestRpcTimeout:
"""Tests for timeout forwarding across all four codegen branches:
- session-scoped with params
- session-scoped without params
- server-scoped with params
- server-scoped without params
"""

# ── session-scoped, with params ──────────────────────────────────

@pytest.mark.asyncio
async def test_default_timeout_not_forwarded(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"started": True})
api = FleetApi(client, "sess-1")

await api.start(SessionFleetStartParams(prompt="go"))

client.request.assert_called_once()
_, kwargs = client.request.call_args
assert "timeout" not in kwargs

@pytest.mark.asyncio
async def test_custom_timeout_forwarded(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"started": True})
api = FleetApi(client, "sess-1")

await api.start(SessionFleetStartParams(prompt="go"), timeout=600.0)

_, kwargs = client.request.call_args
assert kwargs["timeout"] == 600.0

@pytest.mark.asyncio
async def test_timeout_on_session_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"mode": "plan"})
api = ModeApi(client, "sess-1")

await api.set(SessionModeSetParams(mode=Mode.PLAN), timeout=120.0)

_, kwargs = client.request.call_args
assert kwargs["timeout"] == 120.0

# ── session-scoped, no params ────────────────────────────────────

@pytest.mark.asyncio
async def test_timeout_on_session_no_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"exists": True})
api = PlanApi(client, "sess-1")

await api.read(timeout=90.0)

_, kwargs = client.request.call_args
assert kwargs["timeout"] == 90.0

@pytest.mark.asyncio
async def test_default_timeout_on_session_no_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"exists": True})
api = PlanApi(client, "sess-1")

await api.read()

_, kwargs = client.request.call_args
assert "timeout" not in kwargs

# ── server-scoped, with params ─────────────────────────────────────

@pytest.mark.asyncio
async def test_timeout_on_server_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"tools": []})
api = ToolsApi(client)

await api.list(ToolsListParams(), timeout=60.0)

_, kwargs = client.request.call_args
assert kwargs["timeout"] == 60.0

@pytest.mark.asyncio
async def test_default_timeout_on_server_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"tools": []})
api = ToolsApi(client)

await api.list(ToolsListParams())

_, kwargs = client.request.call_args
assert "timeout" not in kwargs

# ── server-scoped, no params ─────────────────────────────────────

@pytest.mark.asyncio
async def test_timeout_on_server_no_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"models": []})
api = ModelsApi(client)

await api.list(timeout=45.0)

_, kwargs = client.request.call_args
assert kwargs["timeout"] == 45.0

@pytest.mark.asyncio
async def test_default_timeout_on_server_no_params_method(self):
client = AsyncMock()
client.request = AsyncMock(return_value={"models": []})
api = ModelsApi(client)

await api.list()

_, kwargs = client.request.call_args
assert "timeout" not in kwargs
23 changes: 15 additions & 8 deletions scripts/codegen/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,14 @@ if TYPE_CHECKING:

`);
lines.push(typesCode);
lines.push(``);
lines.push(`
def _timeout_kwargs(timeout: Optional[float]) -> dict:
"""Build keyword arguments for optional timeout forwarding."""
if timeout is not None:
return {"timeout": timeout}
return {}

`);

// Emit RPC wrapper classes
if (schema.server) {
Expand Down Expand Up @@ -255,10 +262,10 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession:
const hasParams = isSession ? nonSessionParams.length > 0 : Object.keys(paramProps).length > 0;
const paramsType = toPascalCase(method.rpcMethod) + "Params";

// Build signature with typed params
// Build signature with typed params + optional timeout
const sig = hasParams
? ` async def ${methodName}(self, params: ${paramsType}) -> ${resultType}:`
: ` async def ${methodName}(self) -> ${resultType}:`;
? ` async def ${methodName}(self, params: ${paramsType}, *, timeout: Optional[float] = None) -> ${resultType}:`
: ` async def ${methodName}(self, *, timeout: Optional[float] = None) -> ${resultType}:`;

lines.push(sig);

Expand All @@ -267,16 +274,16 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession:
if (hasParams) {
lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`);
lines.push(` params_dict["sessionId"] = self._session_id`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict))`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
} else {
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}))`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))`);
}
} else {
if (hasParams) {
lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict))`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
} else {
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {}))`);
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {}, **_timeout_kwargs(timeout)))`);
}
}
lines.push(``);
Expand Down
Loading