Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions sentry_sdk/integrations/pydantic_ai/patches/agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from sentry_sdk.utils import capture_internal_exceptions, reraise

from ..spans import invoke_agent_span, update_invoke_agent_span
from ..utils import _capture_exception, pop_agent, push_agent
from ..utils import _capture_exception, get_current_agent, pop_agent, push_agent

from typing import TYPE_CHECKING

try:
from pydantic_ai.agent import Agent # type: ignore
from pydantic_ai.run import AgentRun # type: ignore
except ImportError:
raise DidNotEnable("pydantic-ai not installed")

Expand Down Expand Up @@ -40,8 +41,17 @@ def __init__(
self._isolation_scope: "Any" = None
self._span: "Optional[sentry_sdk.tracing.Span]" = None
self._result: "Any" = None
self._is_passthrough: bool = False

async def __aenter__(self) -> "Any":
# Skip instrumentation if there's already an active agent context.
# This happens when run()/run_stream() internally call iter().
if get_current_agent() is not None:
self._is_passthrough = True
result = await self.original_ctx_manager.__aenter__()
self._result = result
return result

# Set up isolation scope and invoke_agent span
self._isolation_scope = sentry_sdk.isolation_scope()
self._isolation_scope.__enter__()
Expand All @@ -56,8 +66,7 @@ async def __aenter__(self) -> "Any":
)
self._span.__enter__()

# Push agent to contextvar stack after span is successfully created and entered
# This ensures proper pairing with pop_agent() in __aexit__ even if exceptions occur
# Push agent to contextvar stack after span is successfully created
push_agent(self.agent, self.is_streaming)

# Enter the original context manager
Expand All @@ -66,13 +75,24 @@ async def __aenter__(self) -> "Any":
return result

async def __aexit__(self, exc_type: "Any", exc_val: "Any", exc_tb: "Any") -> None:
if self._is_passthrough:
await self.original_ctx_manager.__aexit__(exc_type, exc_val, exc_tb)
return

try:
# Exit the original context manager first
await self.original_ctx_manager.__aexit__(exc_type, exc_val, exc_tb)

# Update span with result if successful
if exc_type is None and self._result and self._span is not None:
update_invoke_agent_span(self._span, self._result)
# AgentRun (from iter()) wraps the final result in .result;
# StreamedRunResult (from run_stream()) is used directly.
if isinstance(self._result, AgentRun):
result = self._result.result
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unprotected AgentRun.result access may raise in __aexit__

Medium Severity

Accessing `self._result.result` on an `AgentRun` in `__aexit__` is unprotected. According to the pydantic-ai docs, `AgentRun.result` is only available after the run reaches an `End` node. If a user exits the `iter()` context manager early (e.g., with a `break` in the iteration loop), `__aexit__` is called with `exc_type=None` even though the run has not completed, so `.result` may raise. The instrumentation code would then introduce an unexpected exception into the user's code. The access needs a `try`/`except` guard, similar to how `update_invoke_agent_span` protects the `.usage()` and `.response` accesses.

Fix in Cursor Fix in Web

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `.result` property on `AgentRun` safely returns `None` when the iteration hasn't ended — it does not raise. This has been the case since pydantic-ai v1.0.0, which is the minimum version supported by Sentry for this integration (docs).

Here's the implementation from the v1.0.0 tag: pydantic_ai/run.py#L122-L138

else:
result = self._result
if result is not None:
update_invoke_agent_span(self._span, result)
finally:
# Pop agent from contextvar stack
pop_agent()
Expand Down Expand Up @@ -136,9 +156,10 @@ async def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":

def _create_streaming_wrapper(
original_func: "Callable[..., Any]",
is_streaming: bool = True,
) -> "Callable[..., Any]":
"""
Wraps run_stream method that returns an async context manager.
Wraps streaming methods (run_stream, iter) that return async context managers.
"""

@wraps(original_func)
Expand All @@ -158,7 +179,7 @@ def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
user_prompt=user_prompt,
model=model,
model_settings=model_settings,
is_streaming=True,
is_streaming=is_streaming,
)

return wrapper
Expand Down Expand Up @@ -210,3 +231,7 @@ def _patch_agent_run() -> None:
Agent.run_stream_events = _create_streaming_events_wrapper(
original_run_stream_events
)

# Patch iter() - same async context manager pattern as run_stream()
original_iter = Agent.iter
Agent.iter = _create_streaming_wrapper(original_iter, is_streaming=False)
110 changes: 110 additions & 0 deletions tests/integrations/pydantic_ai/test_pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2794,3 +2794,113 @@ async def test_set_usage_data_with_cache_tokens(sentry_init, capture_events):
(span_data,) = event["spans"]
assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20


@pytest.mark.asyncio
async def test_agent_iter(sentry_init, capture_events, test_agent):
    """
    Check that driving an agent through iter() yields a single
    invoke_agent transaction that contains chat child spans.
    """
    sentry_init(
        integrations=[PydanticAIIntegration()],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )

    captured = capture_events()

    # Walk the run node by node until it is exhausted.
    async with test_agent.iter("Test input") as run:
        async for _ in run:
            pass

    # Exactly one event (the transaction) must have been captured.
    (txn,) = captured
    child_spans = txn["spans"]

    # The transaction is named after the agent and carries the integration origin.
    assert txn["transaction"] == "invoke_agent test_agent"
    assert txn["contexts"]["trace"]["origin"] == "auto.ai.pydantic_ai"

    # At least one chat span must exist among the children.
    chat_spans = [span for span in child_spans if span["op"] == "gen_ai.chat"]
    assert len(chat_spans) >= 1

    # iter() is node-by-node iteration, not token streaming, so every chat
    # span must be flagged as non-streaming and carry request/usage data.
    for span in chat_spans:
        data = span["data"]
        assert data["gen_ai.response.streaming"] is False
        assert "gen_ai.request.messages" in data
        assert "gen_ai.usage.input_tokens" in data


@pytest.mark.asyncio
async def test_agent_iter_with_tools(sentry_init, capture_events, test_agent):
    """
    Check that a tool invoked during an iter() run is reported as an
    execute_tool span.
    """

    # Register a plain tool on the agent under test; its name is asserted below.
    @test_agent.tool_plain
    def get_data(query: str) -> str:
        """Return data for a query."""
        return f"Result for {query}"

    sentry_init(
        integrations=[PydanticAIIntegration()],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )

    captured = capture_events()

    # Walk the run node by node until it is exhausted.
    async with test_agent.iter("Use get_data tool") as run:
        async for _ in run:
            pass

    # Exactly one event (the transaction) must have been captured.
    (txn,) = captured
    child_spans = txn["spans"]

    # Bucket the child spans by operation.
    chat_spans = [span for span in child_spans if span["op"] == "gen_ai.chat"]
    tool_spans = [span for span in child_spans if span["op"] == "gen_ai.execute_tool"]

    # The tool call must have produced at least one execute_tool span.
    assert len(tool_spans) >= 1

    # Inspect the first tool span: identity plus recorded input and output.
    tool_data = tool_spans[0]["data"]
    assert tool_data["gen_ai.tool.name"] == "get_data"
    assert tool_data["gen_ai.tool.type"] == "function"
    assert "gen_ai.tool.input" in tool_data
    assert "gen_ai.tool.output" in tool_data

    # iter() is not streaming, so no chat span may be flagged as streaming.
    for span in chat_spans:
        assert span["data"]["gen_ai.response.streaming"] is False


@pytest.mark.asyncio
async def test_agent_run_no_duplicate_spans(sentry_init, capture_events, test_agent):
    """
    Check that agent.run() emits exactly one invoke_agent span.

    run() internally calls iter(), so the passthrough logic in the iter()
    wrapper must keep it from opening a second invoke_agent span.
    """
    sentry_init(
        integrations=[PydanticAIIntegration()],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )

    captured = capture_events()

    run_result = await test_agent.run("Test input")
    assert run_result is not None

    # Exactly one event (the transaction) must have been captured.
    (txn,) = captured
    child_spans = txn["spans"]

    # The transaction itself plays the role of the invoke_agent span.
    assert txn["contexts"]["trace"]["op"] == "gen_ai.invoke_agent"

    # No child may be an invoke_agent span; a nested one would be a duplicate.
    nested_invoke_spans = [
        span for span in child_spans if span["op"] == "gen_ai.invoke_agent"
    ]
    assert len(nested_invoke_spans) == 0