diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index 5acc501172..0f104cc8f5 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -30,7 +30,7 @@ class GEN_AI_ALLOWED_MESSAGE_ROLES: GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = { GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"], GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"], - GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"], + GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai", "chatbot"], GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"], } diff --git a/sentry_sdk/integrations/cohere.py b/sentry_sdk/integrations/cohere.py deleted file mode 100644 index f45a02f2b5..0000000000 --- a/sentry_sdk/integrations/cohere.py +++ /dev/null @@ -1,269 +0,0 @@ -import sys -from functools import wraps - -from sentry_sdk import consts -from sentry_sdk.ai.monitoring import record_token_usage -from sentry_sdk.consts import SPANDATA -from sentry_sdk.ai.utils import set_data_normalized - -from typing import TYPE_CHECKING - -from sentry_sdk.tracing_utils import set_span_errored - -if TYPE_CHECKING: - from typing import Any, Callable, Iterator - from sentry_sdk.tracing import Span - -import sentry_sdk -from sentry_sdk.scope import should_send_default_pii -from sentry_sdk.integrations import DidNotEnable, Integration -from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise - -try: - from cohere.client import Client - from cohere.base_client import BaseCohere - from cohere import ( - ChatStreamEndEvent, - NonStreamedChatResponse, - ) - - if TYPE_CHECKING: - from cohere import StreamedChatResponse -except ImportError: - raise DidNotEnable("Cohere not installed") - -try: - # cohere 5.9.3+ - from cohere import StreamEndStreamedChatResponse -except ImportError: - from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse - - -COLLECTED_CHAT_PARAMS = { - "model": SPANDATA.AI_MODEL_ID, - "k": SPANDATA.AI_TOP_K, - "p": SPANDATA.AI_TOP_P, - "seed": SPANDATA.AI_SEED, - "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY, - "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY, - "raw_prompting": SPANDATA.AI_RAW_PROMPTING, -} - -COLLECTED_PII_CHAT_PARAMS = { - "tools": SPANDATA.AI_TOOLS, - "preamble": SPANDATA.AI_PREAMBLE, -} - -COLLECTED_CHAT_RESP_ATTRS = { - "generation_id": SPANDATA.AI_GENERATION_ID, - "is_search_required": SPANDATA.AI_SEARCH_REQUIRED, - "finish_reason": SPANDATA.AI_FINISH_REASON, -} - -COLLECTED_PII_CHAT_RESP_ATTRS = { - "citations": SPANDATA.AI_CITATIONS, - "documents": SPANDATA.AI_DOCUMENTS, - "search_queries": SPANDATA.AI_SEARCH_QUERIES, - "search_results": SPANDATA.AI_SEARCH_RESULTS, - "tool_calls": SPANDATA.AI_TOOL_CALLS, -} - - -class CohereIntegration(Integration): - identifier = "cohere" - origin = f"auto.ai.{identifier}" - - def __init__(self: "CohereIntegration", include_prompts: bool = True) -> None: - self.include_prompts = include_prompts - - @staticmethod - def setup_once() -> None: - BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) - Client.embed = _wrap_embed(Client.embed) - BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) - - -def _capture_exception(exc: "Any") -> None: - set_span_errored() - - event, hint = event_from_exception( - exc, - client_options=sentry_sdk.get_client().options, - mechanism={"type": "cohere", "handled": False}, - ) - sentry_sdk.capture_event(event, hint=hint) - - -def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]": - def collect_chat_response_fields( - span: "Span", res: "NonStreamedChatResponse", include_pii: bool - ) -> None: - if include_pii: - if hasattr(res, "text"): - set_data_normalized( - span, - SPANDATA.AI_RESPONSES, - [res.text], - ) - for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS: - if hasattr(res, pii_attr): - set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr)) - - for attr in COLLECTED_CHAT_RESP_ATTRS: - if hasattr(res, attr): - set_data_normalized(span, "ai." + attr, getattr(res, attr)) - - if hasattr(res, "meta"): - if hasattr(res.meta, "billed_units"): - record_token_usage( - span, - input_tokens=res.meta.billed_units.input_tokens, - output_tokens=res.meta.billed_units.output_tokens, - ) - elif hasattr(res.meta, "tokens"): - record_token_usage( - span, - input_tokens=res.meta.tokens.input_tokens, - output_tokens=res.meta.tokens.output_tokens, - ) - - if hasattr(res.meta, "warnings"): - set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings) - - @wraps(f) - def new_chat(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration(CohereIntegration) - - if ( - integration is None - or "message" not in kwargs - or not isinstance(kwargs.get("message"), str) - ): - return f(*args, **kwargs) - - message = kwargs.get("message") - - span = sentry_sdk.start_span( - op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE, - name="cohere.client.Chat", - origin=CohereIntegration.origin, - ) - span.__enter__() - try: - res = f(*args, **kwargs) - except Exception as e: - exc_info = sys.exc_info() - with capture_internal_exceptions(): - _capture_exception(e) - span.__exit__(None, None, None) - reraise(*exc_info) - - with capture_internal_exceptions(): - if should_send_default_pii() and integration.include_prompts: - set_data_normalized( - span, - SPANDATA.AI_INPUT_MESSAGES, - list( - map( - lambda x: { - "role": getattr(x, "role", "").lower(), - "content": getattr(x, "message", ""), - }, - kwargs.get("chat_history", []), - ) - ) - + [{"role": "user", "content": message}], - ) - for k, v in COLLECTED_PII_CHAT_PARAMS.items(): - if k in kwargs: - set_data_normalized(span, v, kwargs[k]) - - for k, v in COLLECTED_CHAT_PARAMS.items(): - if k in kwargs: - set_data_normalized(span, v, kwargs[k]) - set_data_normalized(span, SPANDATA.AI_STREAMING, False) - - if streaming: - old_iterator = res - - def new_iterator() -> "Iterator[StreamedChatResponse]": - with capture_internal_exceptions(): - for x in old_iterator: - if isinstance(x, ChatStreamEndEvent) or isinstance( - x, StreamEndStreamedChatResponse - ): - collect_chat_response_fields( - span, - x.response, - include_pii=should_send_default_pii() - and integration.include_prompts, - ) - yield x - - span.__exit__(None, None, None) - - return new_iterator() - elif isinstance(res, NonStreamedChatResponse): - collect_chat_response_fields( - span, - res, - include_pii=should_send_default_pii() - and integration.include_prompts, - ) - span.__exit__(None, None, None) - else: - set_data_normalized(span, "unknown_response", True) - span.__exit__(None, None, None) - return res - - return new_chat - - -def _wrap_embed(f: "Callable[..., Any]") -> "Callable[..., Any]": - @wraps(f) - def new_embed(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration(CohereIntegration) - if integration is None: - return f(*args, **kwargs) - - with sentry_sdk.start_span( - op=consts.OP.COHERE_EMBEDDINGS_CREATE, - name="Cohere Embedding Creation", - origin=CohereIntegration.origin, - ) as span: - if "texts" in kwargs and ( - should_send_default_pii() and integration.include_prompts - ): - if isinstance(kwargs["texts"], str): - set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]]) - elif ( - isinstance(kwargs["texts"], list) - and len(kwargs["texts"]) > 0 - and isinstance(kwargs["texts"][0], str) - ): - set_data_normalized( - span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"] - ) - - if "model" in kwargs: - set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"]) - try: - res = f(*args, **kwargs) - except Exception as e: - exc_info = sys.exc_info() - with capture_internal_exceptions(): - _capture_exception(e) - reraise(*exc_info) - if ( - hasattr(res, "meta") - and hasattr(res.meta, "billed_units") - and hasattr(res.meta.billed_units, "input_tokens") - ): - record_token_usage( - span, - input_tokens=res.meta.billed_units.input_tokens, - total_tokens=res.meta.billed_units.input_tokens, - ) - return res - - return new_embed diff --git a/sentry_sdk/integrations/cohere/__init__.py b/sentry_sdk/integrations/cohere/__init__.py new file mode 100644 index 0000000000..f8a26f1fc8 --- /dev/null +++ b/sentry_sdk/integrations/cohere/__init__.py @@ -0,0 +1,127 @@ +import sys +from functools import wraps + +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.ai.utils import set_data_normalized + +from typing import TYPE_CHECKING + +from sentry_sdk.tracing_utils import set_span_errored + +if TYPE_CHECKING: + from typing import Any, Callable + +import sentry_sdk +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.integrations import DidNotEnable, Integration +from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise + +try: + from cohere import __version__ as cohere_version # noqa: F401 +except ImportError: + raise DidNotEnable("Cohere not installed") + +COLLECTED_CHAT_PARAMS = { + "model": SPANDATA.GEN_AI_REQUEST_MODEL, + "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE, + "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, + "k": SPANDATA.GEN_AI_REQUEST_TOP_K, + "p": SPANDATA.GEN_AI_REQUEST_TOP_P, + "seed": SPANDATA.GEN_AI_REQUEST_SEED, + "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, + "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, +} + + +class CohereIntegration(Integration): + identifier = "cohere" + origin = f"auto.ai.{identifier}" + + def __init__(self, include_prompts=True): + # type: (bool) -> None + self.include_prompts = include_prompts + + @staticmethod + def setup_once(): + # type: () -> None + # Lazy imports to avoid circular dependencies: + # v1/v2 import COLLECTED_CHAT_PARAMS and _capture_exception from this module. + from sentry_sdk.integrations.cohere.v1 import setup_v1 + from sentry_sdk.integrations.cohere.v2 import setup_v2 + + setup_v1(_wrap_embed) + setup_v2(_wrap_embed) + + +def _capture_exception(exc): + # type: (Any) -> None + set_span_errored() + + event, hint = event_from_exception( + exc, + client_options=sentry_sdk.get_client().options, + mechanism={"type": "cohere", "handled": False}, + ) + sentry_sdk.capture_event(event, hint=hint) + + +def _wrap_embed(f): + # type: (Callable[..., Any]) -> Callable[..., Any] + @wraps(f) + def new_embed(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + if integration is None: + return f(*args, **kwargs) + + model = kwargs.get("model", "") + + with sentry_sdk.start_span( + op=OP.GEN_AI_EMBEDDINGS, + name=f"embeddings {model}".strip(), + origin=CohereIntegration.origin, + ) as span: + set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere") + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings") + + if "texts" in kwargs and ( + should_send_default_pii() and integration.include_prompts + ): + if isinstance(kwargs["texts"], str): + set_data_normalized( + span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, [kwargs["texts"]] + ) + elif ( + isinstance(kwargs["texts"], list) + and len(kwargs["texts"]) > 0 + and isinstance(kwargs["texts"][0], str) + ): + set_data_normalized( + span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, kwargs["texts"] + ) + + if "model" in kwargs: + set_data_normalized( + span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"] + ) + try: + res = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + reraise(*exc_info) + if ( + hasattr(res, "meta") + and hasattr(res.meta, "billed_units") + and hasattr(res.meta.billed_units, "input_tokens") + ): + record_token_usage( + span, + input_tokens=res.meta.billed_units.input_tokens, + total_tokens=res.meta.billed_units.input_tokens, + ) + return res + + return new_embed diff --git a/sentry_sdk/integrations/cohere/v1.py b/sentry_sdk/integrations/cohere/v1.py new file mode 100644 index 0000000000..c0e33cfbc7 --- /dev/null +++ b/sentry_sdk/integrations/cohere/v1.py @@ -0,0 +1,210 @@ +import sys +from functools import wraps + +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.ai.utils import ( + set_data_normalized, + normalize_message_roles, + truncate_and_annotate_messages, +) + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, Iterator + from sentry_sdk.tracing import Span + +import sentry_sdk +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.utils import capture_internal_exceptions, reraise + +from sentry_sdk.integrations.cohere import ( + CohereIntegration, + COLLECTED_CHAT_PARAMS, + _capture_exception, +) + +COLLECTED_PII_CHAT_PARAMS = { + "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, + "preamble": SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS, +} + +COLLECTED_CHAT_RESP_ATTRS = { + "generation_id": SPANDATA.GEN_AI_RESPONSE_ID, + "finish_reason": SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, +} + +COLLECTED_PII_CHAT_RESP_ATTRS = { + "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, +} + + +def setup_v1(wrap_embed_fn): + # type: (Callable[..., Any]) -> None + """Called from CohereIntegration.setup_once() to patch V1 Client methods.""" + try: + from cohere.client import Client + from cohere.base_client import BaseCohere + except ImportError: + return + + BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) + BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) + Client.embed = wrap_embed_fn(Client.embed) + + +def _wrap_chat(f, streaming): + # type: (Callable[..., Any], bool) -> Callable[..., Any] + + try: + from cohere import ( + ChatStreamEndEvent, + NonStreamedChatResponse, + ) + + if TYPE_CHECKING: + from cohere import StreamedChatResponse + except ImportError: + return f + + try: + # cohere 5.9.3+ + from cohere import StreamEndStreamedChatResponse + except ImportError: + from cohere import ( + StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse, + ) + + def collect_chat_response_fields(span, res, include_pii): + # type: (Span, NonStreamedChatResponse, bool) -> None + if include_pii: + if hasattr(res, "text"): + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + [res.text], + ) + for attr, spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS.items(): + if hasattr(res, attr): + set_data_normalized(span, spandata_key, getattr(res, attr)) + + for attr, spandata_key in COLLECTED_CHAT_RESP_ATTRS.items(): + if hasattr(res, attr): + set_data_normalized(span, spandata_key, getattr(res, attr)) + + if hasattr(res, "meta"): + if hasattr(res.meta, "billed_units"): + record_token_usage( + span, + input_tokens=res.meta.billed_units.input_tokens, + output_tokens=res.meta.billed_units.output_tokens, + ) + elif hasattr(res.meta, "tokens"): + record_token_usage( + span, + input_tokens=res.meta.tokens.input_tokens, + output_tokens=res.meta.tokens.output_tokens, + ) + + @wraps(f) + def new_chat(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + + if ( + integration is None + or "message" not in kwargs + or not isinstance(kwargs.get("message"), str) + ): + return f(*args, **kwargs) + + message = kwargs.get("message") + model = kwargs.get("model", "") + + span = sentry_sdk.start_span( + op=OP.GEN_AI_CHAT, + name=f"chat {model}".strip(), + origin=CohereIntegration.origin, + ) + span.__enter__() + try: + res = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + span.__exit__(None, None, None) + reraise(*exc_info) + + with capture_internal_exceptions(): + set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere") + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat") + if model: + set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model) + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, model) + + if should_send_default_pii() and integration.include_prompts: + messages = [] + for x in kwargs.get("chat_history", []): + messages.append( + { + "role": getattr(x, "role", "").lower(), + "content": getattr(x, "message", ""), + } + ) + messages.append({"role": "user", "content": message}) + messages = normalize_message_roles(messages) + scope = sentry_sdk.get_current_scope() + messages_data = truncate_and_annotate_messages(messages, span, scope) + if messages_data is not None: + set_data_normalized( + span, + SPANDATA.GEN_AI_REQUEST_MESSAGES, + messages_data, + unpack=False, + ) + for k, v in COLLECTED_PII_CHAT_PARAMS.items(): + if k in kwargs: + set_data_normalized(span, v, kwargs[k]) + + for k, v in COLLECTED_CHAT_PARAMS.items(): + if k in kwargs: + set_data_normalized(span, v, kwargs[k]) + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming) + + if streaming: + old_iterator = res + + def new_iterator(): + # type: () -> Iterator[StreamedChatResponse] + with capture_internal_exceptions(): + for x in old_iterator: + if isinstance(x, ChatStreamEndEvent) or isinstance( + x, StreamEndStreamedChatResponse + ): + collect_chat_response_fields( + span, + x.response, + include_pii=should_send_default_pii() + and integration.include_prompts, + ) + yield x + + span.__exit__(None, None, None) + + return new_iterator() + elif isinstance(res, NonStreamedChatResponse): + collect_chat_response_fields( + span, + res, + include_pii=should_send_default_pii() + and integration.include_prompts, + ) + span.__exit__(None, None, None) + else: + set_data_normalized(span, "unknown_response", True) + span.__exit__(None, None, None) + return res + + return new_chat diff --git a/sentry_sdk/integrations/cohere/v2.py b/sentry_sdk/integrations/cohere/v2.py new file mode 100644 index 0000000000..0a12828462 --- /dev/null +++ b/sentry_sdk/integrations/cohere/v2.py @@ -0,0 +1,280 @@ +import sys +from functools import wraps + +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.ai.utils import ( + set_data_normalized, + normalize_message_roles, + truncate_and_annotate_messages, +) + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, Iterator + from sentry_sdk.tracing import Span + +import sentry_sdk +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.utils import capture_internal_exceptions, reraise + +from sentry_sdk.integrations.cohere import ( + CohereIntegration, + COLLECTED_CHAT_PARAMS, + _capture_exception, +) + +try: + from cohere.v2.client import V2Client as CohereV2Client + + # Type locations changed between cohere versions: + # 5.13.x: cohere.types (ChatResponse, MessageEndStreamedChatResponseV2) + # 5.20+: cohere.v2.types (V2ChatResponse, MessageEndV2ChatStreamResponse) + try: + from cohere.v2.types import V2ChatResponse + from cohere.v2.types import MessageEndV2ChatStreamResponse + + if TYPE_CHECKING: + from cohere.v2.types import V2ChatStreamResponse + except ImportError: + from cohere.types import ChatResponse as V2ChatResponse + from cohere.types import ( + MessageEndStreamedChatResponseV2 as MessageEndV2ChatStreamResponse, + ) + + if TYPE_CHECKING: + from cohere.types import StreamedChatResponseV2 as V2ChatStreamResponse + + _has_v2 = True +except ImportError: + _has_v2 = False + + +def setup_v2(wrap_embed_fn): + # type: (Callable[..., Any]) -> None + """Called from CohereIntegration.setup_once() to patch V2Client methods. + + The embed wrapper is passed in from __init__.py to reuse the same _wrap_embed + for both V1 and V2, since the embed response format (.meta.billed_units) + is identical across both API versions. + """ + if not _has_v2: + return + + CohereV2Client.chat = _wrap_chat_v2(CohereV2Client.chat, streaming=False) + CohereV2Client.chat_stream = _wrap_chat_v2( + CohereV2Client.chat_stream, streaming=True + ) + CohereV2Client.embed = wrap_embed_fn(CohereV2Client.embed) + + +def _extract_messages_v2(messages): + # type: (Any) -> list[dict[str, str]] + """Extract role/content dicts from V2-style message objects. + + Handles both plain dicts and Pydantic model instances. + """ + result = [] + for msg in messages: + if isinstance(msg, dict): + role = msg.get("role", "unknown") + content = msg.get("content", "") + else: + role = getattr(msg, "role", "unknown") + content = getattr(msg, "content", "") + if isinstance(content, str): + text = content + elif isinstance(content, list): + text = " ".join( + ( + item.get("text", "") + if isinstance(item, dict) + else getattr(item, "text", "") + ) + for item in content + if (isinstance(item, dict) and "text" in item) or hasattr(item, "text") + ) + else: + text = str(content) if content else "" + result.append({"role": role, "content": text}) + return result + + +def _record_token_usage_v2(span, usage): + # type: (Span, Any) -> None + """Extract and record token usage from a V2 Usage object.""" + if hasattr(usage, "billed_units") and usage.billed_units is not None: + record_token_usage( + span, + input_tokens=getattr(usage.billed_units, "input_tokens", None), + output_tokens=getattr(usage.billed_units, "output_tokens", None), + ) + elif hasattr(usage, "tokens") and usage.tokens is not None: + record_token_usage( + span, + input_tokens=getattr(usage.tokens, "input_tokens", None), + output_tokens=getattr(usage.tokens, "output_tokens", None), + ) + + +def _wrap_chat_v2(f, streaming): + # type: (Callable[..., Any], bool) -> Callable[..., Any] + def collect_v2_response_fields(span, res, include_pii): + # type: (Span, V2ChatResponse, bool) -> None + if include_pii: + if ( + hasattr(res, "message") + and hasattr(res.message, "content") + and res.message.content + ): + texts = [ + item.text for item in res.message.content if hasattr(item, "text") + ] + if texts: + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, texts) + + if ( + hasattr(res, "message") + and hasattr(res.message, "tool_calls") + and res.message.tool_calls + ): + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, + res.message.tool_calls, + ) + + if hasattr(res, "id"): + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_ID, res.id) + + if hasattr(res, "finish_reason"): + set_data_normalized( + span, SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, res.finish_reason + ) + + if hasattr(res, "usage") and res.usage is not None: + _record_token_usage_v2(span, res.usage) + + @wraps(f) + def new_chat(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + + if integration is None or "messages" not in kwargs: + return f(*args, **kwargs) + + model = kwargs.get("model", "") + + span = sentry_sdk.start_span( + op=OP.GEN_AI_CHAT, + name=f"chat {model}".strip(), + origin=CohereIntegration.origin, + ) + span.__enter__() + try: + res = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + span.__exit__(None, None, None) + reraise(*exc_info) + + with capture_internal_exceptions(): + set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere") + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat") + if model: + set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model) + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, model) + + if should_send_default_pii() and integration.include_prompts: + messages = _extract_messages_v2(kwargs.get("messages", [])) + messages = normalize_message_roles(messages) + scope = sentry_sdk.get_current_scope() + messages_data = truncate_and_annotate_messages(messages, span, scope) + if messages_data is not None: + set_data_normalized( + span, + SPANDATA.GEN_AI_REQUEST_MESSAGES, + messages_data, + unpack=False, + ) + if "tools" in kwargs: + set_data_normalized( + span, + SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, + kwargs["tools"], + ) + + for k, v in COLLECTED_CHAT_PARAMS.items(): + if k in kwargs: + set_data_normalized(span, v, kwargs[k]) + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming) + + if streaming: + old_iterator = res + + def new_iterator(): + # type: () -> Iterator[V2ChatStreamResponse] + collected_text = [] + with capture_internal_exceptions(): + for x in old_iterator: + if ( + hasattr(x, "type") + and x.type == "content-delta" + and hasattr(x, "delta") + and x.delta is not None + ): + msg = getattr(x.delta, "message", None) + if msg is not None: + content = getattr(msg, "content", None) + if content is not None and hasattr(content, "text"): + collected_text.append(content.text) + + if isinstance(x, MessageEndV2ChatStreamResponse): + include_pii = ( + should_send_default_pii() + and integration.include_prompts + ) + if include_pii and collected_text: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + ["".join(collected_text)], + ) + if hasattr(x, "id"): + set_data_normalized( + span, SPANDATA.GEN_AI_RESPONSE_ID, x.id + ) + if hasattr(x, "delta") and x.delta is not None: + if hasattr(x.delta, "finish_reason"): + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, + x.delta.finish_reason, + ) + if ( + hasattr(x.delta, "usage") + and x.delta.usage is not None + ): + _record_token_usage_v2(span, x.delta.usage) + yield x + + span.__exit__(None, None, None) + + return new_iterator() + elif isinstance(res, V2ChatResponse): + collect_v2_response_fields( + span, + res, + include_pii=should_send_default_pii() + and integration.include_prompts, + ) + span.__exit__(None, None, None) + else: + set_data_normalized(span, "unknown_response", True) + span.__exit__(None, None, None) + return res + + return new_chat diff --git a/tests/integrations/cohere/test_cohere.py b/tests/integrations/cohere/test_cohere.py index 9ff56ed697..3da2f616ed 100644 --- a/tests/integrations/cohere/test_cohere.py +++ b/tests/integrations/cohere/test_cohere.py @@ -2,21 +2,32 @@ import httpx import pytest +from unittest import mock + +from httpx import Client as HTTPXClient + from cohere import Client, ChatMessage from sentry_sdk import start_transaction from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations.cohere import CohereIntegration -from unittest import mock # python 3.3 and above -from httpx import Client as HTTPXClient +try: + from cohere import ClientV2 + + has_v2 = True +except ImportError: + has_v2 = False + + +# --- V1 Chat (non-streaming) --- @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_nonstreaming_chat( +def test_v1_nonstreaming_chat( sentry_init, capture_events, send_default_pii, include_prompts ): sentry_init( @@ -32,6 +43,8 @@ def test_nonstreaming_chat( 200, json={ "text": "the model response", + "generation_id": "gen-123", + "finish_reason": "COMPLETE", "meta": { "billed_units": { "output_tokens": 10, @@ -47,40 +60,41 @@ def test_nonstreaming_chat( model="some-model", chat_history=[ChatMessage(role="SYSTEM", message="some context")], message="hello", - ).text + ) - assert response == "the model response" + assert response.text == "the model response" tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.cohere" - assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model" + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False if send_default_pii and include_prompts: - assert ( - '{"role": "system", "content": "some context"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert ( - '{"role": "user", "content": "hello"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert "the model response" in span["data"][SPANDATA.AI_RESPONSES] + assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] - assert SPANDATA.AI_RESPONSES not in span["data"] + assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"] assert span["data"]["gen_ai.usage.output_tokens"] == 10 assert span["data"]["gen_ai.usage.input_tokens"] == 20 assert span["data"]["gen_ai.usage.total_tokens"] == 30 -# noinspection PyTypeChecker +# --- V1 Chat (streaming) --- + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_prompts): +def test_v1_streaming_chat( + sentry_init, capture_events, send_default_pii, include_prompts +): sentry_init( integrations=[CohereIntegration(include_prompts=include_prompts)], traces_sample_rate=1.0, @@ -102,6 +116,7 @@ def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_p "finish_reason": "COMPLETE", "response": { "text": "the model response", + "generation_id": "gen-123", "meta": { "billed_units": { "output_tokens": 10, @@ -130,29 +145,29 @@ def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_p tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.cohere" - assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model" + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True if send_default_pii and include_prompts: - assert ( - '{"role": "system", "content": "some context"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert ( - '{"role": "user", "content": "hello"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert "the model response" in span["data"][SPANDATA.AI_RESPONSES] + assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] - assert SPANDATA.AI_RESPONSES not in span["data"] + assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"] assert span["data"]["gen_ai.usage.output_tokens"] == 10 assert span["data"]["gen_ai.usage.input_tokens"] == 20 assert span["data"]["gen_ai.usage.total_tokens"] == 30 -def test_bad_chat(sentry_init, capture_events): +# --- V1 Error --- + + +def test_v1_bad_chat(sentry_init, capture_events): sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0) events = capture_events() @@ -167,7 +182,7 @@ def test_bad_chat(sentry_init, capture_events): assert event["level"] == "error" -def test_span_status_error(sentry_init, capture_events): +def test_v1_span_status_error(sentry_init, capture_events): sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0) events = capture_events() @@ -186,11 +201,14 @@ def test_span_status_error(sentry_init, capture_events): assert transaction["contexts"]["trace"]["status"] == "internal_error" +# --- V1 Embed --- + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_embed(sentry_init, capture_events, send_default_pii, include_prompts): +def test_v1_embed(sentry_init, capture_events, send_default_pii, include_prompts): sentry_init( integrations=[CohereIntegration(include_prompts=include_prompts)], traces_sample_rate=1.0, @@ -217,67 +235,216 @@ def test_embed(sentry_init, capture_events, send_default_pii, include_prompts): ) with start_transaction(name="cohere tx"): - response = client.embed(texts=["hello"], model="text-embedding-3-large") + response = client.embed(texts=["hello"], model="embed-english-v3.0") assert len(response.embeddings[0]) == 3 tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.embeddings.create.cohere" + assert span["op"] == "gen_ai.embeddings" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings" + if send_default_pii and include_prompts: - assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES] + assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] assert span["data"]["gen_ai.usage.input_tokens"] == 10 assert span["data"]["gen_ai.usage.total_tokens"] == 10 -def test_span_origin_chat(sentry_init, capture_events): +# --- V2 Chat (non-streaming) --- + + +@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available") +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +def test_v2_nonstreaming_chat( + sentry_init, capture_events, send_default_pii, include_prompts +): sentry_init( - integrations=[CohereIntegration()], + integrations=[CohereIntegration(include_prompts=include_prompts)], traces_sample_rate=1.0, + send_default_pii=send_default_pii, ) events = capture_events() - client = Client(api_key="z") + client = ClientV2(api_key="z") HTTPXClient.request = mock.Mock( return_value=httpx.Response( 200, json={ - "text": "the model response", - "meta": { + "id": "resp-123", + "finish_reason": "COMPLETE", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "the model response"}], + }, + "usage": { "billed_units": { - "output_tokens": 10, "input_tokens": 20, - } + "output_tokens": 10, + }, + "tokens": { + "input_tokens": 25, + "output_tokens": 15, + }, }, }, ) ) with start_transaction(name="cohere tx"): + response = client.chat( + model="some-model", + messages=[ + {"role": "system", "content": "some context"}, + {"role": "user", "content": "hello"}, + ], + ) + + assert response.message.content[0].text == "the model response" + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "some-model" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False + assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "resp-123" + + if send_default_pii and include_prompts: + assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] + else: + assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"] + + assert span["data"]["gen_ai.usage.output_tokens"] == 10 + assert span["data"]["gen_ai.usage.input_tokens"] == 20 + assert span["data"]["gen_ai.usage.total_tokens"] == 30 + + +# --- V2 Chat (streaming) --- + + +@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available") +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +def test_v2_streaming_chat( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[CohereIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = ClientV2(api_key="z") + + # SSE format: each event is "data: ...\n\n" + sse_content = "".join( + [ + 'data: {"type":"message-start","id":"resp-123"}\n', + "\n", + 'data: {"type":"content-delta","index":0,"delta":{"type":"content-delta","message":{"role":"assistant","content":{"type":"text","text":"the model "}}}}\n', + "\n", + 'data: {"type":"content-delta","index":0,"delta":{"type":"content-delta","message":{"role":"assistant","content":{"type":"text","text":"response"}}}}\n', + "\n", + 'data: {"type":"message-end","id":"resp-123","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":20,"output_tokens":10},"tokens":{"input_tokens":25,"output_tokens":15}}}}\n', + "\n", + ] + ) + + HTTPXClient.send = mock.Mock( + return_value=httpx.Response( + 200, + content=sse_content, + headers={"content-type": "text/event-stream"}, + ) + ) + + with start_transaction(name="cohere tx"): + responses = list( + client.chat_stream( + model="some-model", + messages=[ + {"role": "user", "content": "hello"}, + ], + ) + ) + + assert len(responses) > 0 + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True + + if send_default_pii and include_prompts: + assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] + else: + assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"] + + assert span["data"]["gen_ai.usage.output_tokens"] == 10 + assert span["data"]["gen_ai.usage.input_tokens"] == 20 + assert span["data"]["gen_ai.usage.total_tokens"] == 30 + + +# --- V2 Error --- + + +@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available") +def test_v2_bad_chat(sentry_init, capture_events): + sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0) + events = capture_events() + + client = ClientV2(api_key="z") + HTTPXClient.request = mock.Mock( + side_effect=httpx.HTTPError("API rate limit reached") + ) + with pytest.raises(httpx.HTTPError): client.chat( model="some-model", - chat_history=[ChatMessage(role="SYSTEM", message="some context")], - message="hello", - ).text + messages=[{"role": "user", "content": "hello"}], + ) (event,) = events + assert event["level"] == "error" + - assert event["contexts"]["trace"]["origin"] == "manual" - assert event["spans"][0]["origin"] == "auto.ai.cohere" +# --- V2 Embed --- -def test_span_origin_embed(sentry_init, capture_events): +@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available") +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +def test_v2_embed(sentry_init, capture_events, send_default_pii, include_prompts): sentry_init( - integrations=[CohereIntegration()], + integrations=[CohereIntegration(include_prompts=include_prompts)], traces_sample_rate=1.0, + send_default_pii=send_default_pii, ) events = capture_events() - client = Client(api_key="z") + client = ClientV2(api_key="z") HTTPXClient.request = mock.Mock( return_value=httpx.Response( 200, @@ -285,7 +452,7 @@ def test_span_origin_embed(sentry_init, capture_events): "response_type": "embeddings_floats", "id": "1", "texts": ["hello"], - "embeddings": [[1.0, 2.0, 3.0]], + "embeddings": {"float": [[1.0, 2.0, 3.0]]}, "meta": { "billed_units": { "input_tokens": 10, @@ -296,9 +463,25 @@ def test_span_origin_embed(sentry_init, capture_events): ) with start_transaction(name="cohere tx"): - client.embed(texts=["hello"], model="text-embedding-3-large") + client.embed( + texts=["hello"], + model="embed-english-v3.0", + input_type="search_document", + embedding_types=["float"], + ) - (event,) = events + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "gen_ai.embeddings" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings" + + if send_default_pii and include_prompts: + assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT] + else: + assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] - assert event["contexts"]["trace"]["origin"] == "manual" - assert event["spans"][0]["origin"] == "auto.ai.cohere" + assert span["data"]["gen_ai.usage.input_tokens"] == 10 + assert span["data"]["gen_ai.usage.total_tokens"] == 10