Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-langchain"
version = "0.7.7"
version = "0.7.8"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
11 changes: 11 additions & 0 deletions src/uipath_langchain/agent/tools/tool_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
LowCodeAgentDefinition,
)

from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION

from .context_tool import create_context_tool
from .escalation_tool import create_escalation_tool
from .extraction_tool import create_ixp_extraction_tool
Expand Down Expand Up @@ -54,6 +56,15 @@ async def create_tools_from_resources(
else:
tools.append(tool)

if agent.is_conversational:
props = getattr(resource, "properties", None)
if props and getattr(
props, REQUIRE_CONVERSATIONAL_CONFIRMATION, False
):
if tool.metadata is None:
tool.metadata = {}
tool.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True

return tools


Expand Down
19 changes: 17 additions & 2 deletions src/uipath_langchain/agent/tools/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
extract_current_tool_call_index,
find_latest_ai_message,
)
from uipath_langchain.chat.hitl import check_tool_confirmation

# the type safety can be improved with generics
ToolWrapperReturnType = dict[str, Any] | Command[Any] | None
Expand Down Expand Up @@ -79,6 +80,10 @@ def _func(self, state: AgentGraphState) -> OutputType:
if call is None:
return None

confirmation = check_tool_confirmation(call, self.tool)
if confirmation is not None and confirmation.cancelled:
return self._process_result(call, confirmation.cancelled)

try:
if self.wrapper:
inputs = self._prepare_wrapper_inputs(
Expand All @@ -87,7 +92,10 @@ def _func(self, state: AgentGraphState) -> OutputType:
result = self.wrapper(*inputs)
else:
result = self.tool.invoke(call)
return self._process_result(call, result)
output = self._process_result(call, result)
if confirmation is not None:
confirmation.annotate_result(output)
return output
except Exception as e:
if self.handle_tool_errors:
return self._process_error_result(call, e)
Expand All @@ -98,6 +106,10 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
if call is None:
return None

confirmation = check_tool_confirmation(call, self.tool)
if confirmation is not None and confirmation.cancelled:
return self._process_result(call, confirmation.cancelled)

try:
if self.awrapper:
inputs = self._prepare_wrapper_inputs(
Expand All @@ -106,7 +118,10 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
result = await self.awrapper(*inputs)
else:
result = await self.tool.ainvoke(call)
return self._process_result(call, result)
output = self._process_result(call, result)
if confirmation is not None:
confirmation.annotate_result(output)
return output
except Exception as e:
if self.handle_tool_errors:
return self._process_error_result(call, e)
Expand Down
74 changes: 68 additions & 6 deletions src/uipath_langchain/chat/hitl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
import functools
import inspect
from inspect import Parameter
from typing import Annotated, Any, Callable
from typing import Annotated, Any, Callable, NamedTuple

from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.tools import BaseTool, InjectedToolCallId
from langchain_core.tools import tool as langchain_tool
from langgraph.types import interrupt
from uipath.core.chat import (
UiPathConversationToolCallConfirmationValue,
)

_CANCELLED_MESSAGE = "Cancelled by user"
CANCELLED_MESSAGE = "Cancelled by user"
ARGS_MODIFIED_MESSAGE = "Tool arguments were modified by the user"
CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args"
REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation"


class ConfirmationResult(NamedTuple):
"""Result of a tool confirmation check."""

cancelled: ToolMessage | None # ToolMessage if cancelled, None if approved
args_modified: bool
approved_args: dict[str, Any] | None = None

def annotate_result(self, output: dict[str, Any] | Any) -> None:
"""Apply confirmation metadata to a tool result message."""
msg = None
if isinstance(output, dict):
messages = output.get("messages")
if messages:
msg = messages[0]
if msg is None:
return
if self.approved_args is not None:
msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
self.approved_args
)
if self.args_modified:
msg.content = (
f'{{"meta": "{ARGS_MODIFIED_MESSAGE}", "result": {msg.content}}}'
)


def _patch_span_input(approved_args: dict[str, Any]) -> None:
Expand Down Expand Up @@ -53,7 +83,7 @@ def _patch_span_input(approved_args: dict[str, Any]) -> None:
pass


def _request_approval(
def request_approval(
tool_args: dict[str, Any],
tool: BaseTool,
) -> dict[str, Any] | None:
Expand Down Expand Up @@ -89,7 +119,39 @@ def _request_approval(
if not confirmation.get("approved", True):
return None

return confirmation.get("input") or tool_args
return (
confirmation.get("input")
if confirmation.get("input") is not None
else tool_args
)


def check_tool_confirmation(
call: ToolCall, tool: BaseTool
) -> ConfirmationResult | None:
if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)):
return None

original_args = call["args"]
approved_args = request_approval(
{**original_args, "tool_call_id": call["id"]}, tool
)
if approved_args is None:
cancelled_msg = ToolMessage(
content=CANCELLED_MESSAGE,
name=call["name"],
tool_call_id=call["id"],
)
cancelled_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
original_args
)
return ConfirmationResult(cancelled=cancelled_msg, args_modified=False)
call["args"] = approved_args
return ConfirmationResult(
cancelled=None,
args_modified=approved_args != original_args,
approved_args=approved_args,
)


def requires_approval(
Expand All @@ -107,9 +169,9 @@ def decorator(fn: Callable[..., Any]) -> BaseTool:
# wrap the tool/function
@functools.wraps(fn)
def wrapper(**tool_args: Any) -> Any:
approved_args = _request_approval(tool_args, _created_tool[0])
approved_args = request_approval(tool_args, _created_tool[0])
if approved_args is None:
return _CANCELLED_MESSAGE
return {"meta": CANCELLED_MESSAGE}
_patch_span_input(approved_args)
return fn(**approved_args)

Expand Down
30 changes: 24 additions & 6 deletions src/uipath_langchain/runtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
)
from uipath.runtime import UiPathRuntimeStorageProtocol

from uipath_langchain.chat.hitl import CONVERSATIONAL_APPROVED_TOOL_ARGS

from ._citations import CitationStreamProcessor, extract_citations_from_text

logger = logging.getLogger(__name__)
Expand All @@ -58,6 +60,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
"""Initialize the mapper with empty state."""
self.runtime_id = runtime_id
self.storage = storage
self.confirmation_tool_names: set[str] = set[str]()
self.current_message: AIMessageChunk
self.seen_message_ids: set[str] = set()
self._storage_lock = asyncio.Lock()
Expand Down Expand Up @@ -389,11 +392,14 @@ async def map_current_message_to_start_tool_call_events(self):
tool_call_id_to_message_id_map[tool_call_id] = (
self.current_message.id
)
events.append(
self.map_tool_call_to_tool_call_start_event(
self.current_message.id, tool_call

if tool_call["name"] in self.confirmation_tool_names:
# defer tool call for HITL
events.append(
self.map_tool_call_to_tool_call_start_event(
self.current_message.id, tool_call
)
)
)

if self.storage is not None:
await self.storage.set_value(
Expand Down Expand Up @@ -426,7 +432,19 @@ async def map_tool_message_to_events(
# Keep as string if not valid JSON
pass

events = [
events: list[UiPathConversationMessageEvent] = []

# Emit deferred startToolCall for confirmation tools (skipped in Pass 1)
approved_args = message.response_metadata.get(CONVERSATIONAL_APPROVED_TOOL_ARGS)
if approved_args is not None:
tool_call = ToolCall(
name=message.name or "", args=approved_args, id=message.tool_call_id
)
events.append(
self.map_tool_call_to_tool_call_start_event(message_id, tool_call)
)

events.append(
UiPathConversationMessageEvent(
message_id=message_id,
tool_call=UiPathConversationToolCallEvent(
Expand All @@ -438,7 +456,7 @@ async def map_tool_message_to_events(
),
),
)
]
)

if is_last_tool_call:
events.append(self.map_to_message_end_event(message_id))
Expand Down
16 changes: 16 additions & 0 deletions src/uipath_langchain/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from uipath.runtime.schema import UiPathRuntimeSchema

from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
self.entrypoint: str | None = entrypoint
self.callbacks: list[BaseCallbackHandler] = callbacks or []
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
self.chat.confirmation_tool_names = self._detect_confirmation_tools()
self._middleware_node_names: set[str] = self._detect_middleware_nodes()

async def execute(
Expand Down Expand Up @@ -486,6 +488,20 @@ def _detect_middleware_nodes(self) -> set[str]:

return middleware_nodes

def _detect_confirmation_tools(self) -> set[str]:
confirmation_tools: set[str] = set()
for node_name, node_spec in self.graph.nodes.items():
bound = getattr(node_spec, "bound", None)
if bound is None:
continue
tool = getattr(bound, "tool", None)
if tool is None:
continue
metadata = getattr(tool, "metadata", None) or {}
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
confirmation_tools.add(getattr(tool, "name", node_name))
return confirmation_tools

def _is_middleware_node(self, node_name: str) -> bool:
"""Check if a node name represents a middleware node."""
return node_name in self._middleware_node_names
Expand Down
Loading