Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion sentry_sdk/integrations/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _collect_ai_data(
model: "str | None",
usage: "_RecordedUsage",
content_blocks: "list[str]",
) -> "tuple[str | None, _RecordedUsage, list[str]]":
response_id: "str | None" = None,
) -> "tuple[str | None, _RecordedUsage, list[str], str | None]":
"""
Collect model information, token usage, and collect content blocks from the AI streaming response.
"""
Expand All @@ -146,6 +147,7 @@ def _collect_ai_data(
# https://github.com/anthropics/anthropic-sdk-python/blob/9c485f6966e10ae0ea9eabb3a921d2ea8145a25b/src/anthropic/lib/streaming/_messages.py#L433-L518
if event.type == "message_start":
model = event.message.model or model
response_id = event.message.id

incoming_usage = event.message.usage
usage.output_tokens = incoming_usage.output_tokens
Expand All @@ -162,6 +164,7 @@ def _collect_ai_data(
model,
usage,
content_blocks,
response_id,
)

# Counterintuitive, but message_delta contains cumulative token counts :)
Expand Down Expand Up @@ -190,12 +193,14 @@ def _collect_ai_data(
model,
usage,
content_blocks,
response_id,
)

return (
model,
usage,
content_blocks,
response_id,
)


Expand Down Expand Up @@ -348,10 +353,13 @@ def _set_output_data(
cache_write_input_tokens: "int | None",
content_blocks: "list[Any]",
finish_span: bool = False,
response_id: "str | None" = None,
) -> None:
"""
Set output data for the span based on the AI response."""
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model)
if response_id is not None:
span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, response_id)
if should_send_default_pii() and integration.include_prompts:
output_messages: "dict[str, list[Any]]" = {
"response": [],
Expand Down Expand Up @@ -443,6 +451,7 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=content_blocks,
finish_span=True,
response_id=getattr(result, "id", None),
)

# Streaming response
Expand All @@ -453,17 +462,20 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
model = None
usage = _RecordedUsage()
content_blocks: "list[str]" = []
response_id = None

for event in old_iterator:
(
model,
usage,
content_blocks,
response_id,
) = _collect_ai_data(
event,
model,
usage,
content_blocks,
response_id,
)
yield event

Expand All @@ -485,23 +497,27 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
cache_write_input_tokens=usage.cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
response_id=response_id,
)

async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model = None
usage = _RecordedUsage()
content_blocks: "list[str]" = []
response_id = None

async for event in old_iterator:
(
model,
usage,
content_blocks,
response_id,
) = _collect_ai_data(
event,
model,
usage,
content_blocks,
response_id,
)
yield event

Expand All @@ -523,6 +539,7 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
cache_write_input_tokens=usage.cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
response_id=response_id,
)

if str(type(result._iterator)) == "<class 'async_generator'>":
Expand Down
9 changes: 7 additions & 2 deletions tests/integrations/anthropic/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def __call__(self, *args, **kwargs):
ANTHROPIC_VERSION = package_version("anthropic")

EXAMPLE_MESSAGE = Message(
id="id",
id="msg_01XFDUDYJgAACzvnptvVoYEL",
model="model",
role="assistant",
content=[TextBlock(type="text", text="Hi, I'm Claude.")],
Expand Down Expand Up @@ -134,6 +134,7 @@ def test_nonstreaming_create_message(
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"


@pytest.mark.asyncio
Expand Down Expand Up @@ -204,6 +205,7 @@ async def test_nonstreaming_create_message_async(
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -306,6 +308,7 @@ def test_streaming_create_message(
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"


@pytest.mark.asyncio
Expand Down Expand Up @@ -411,6 +414,7 @@ async def test_streaming_create_message_async(
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"


@pytest.mark.skipif(
Expand Down Expand Up @@ -852,13 +856,14 @@ def test_collect_ai_data_with_input_json_delta():

content_blocks = []

model, new_usage, new_content_blocks = _collect_ai_data(
model, new_usage, new_content_blocks, response_id = _collect_ai_data(
event, model, usage, content_blocks
)
assert model is None
assert new_usage.input_tokens == usage.input_tokens
assert new_usage.output_tokens == usage.output_tokens
assert new_content_blocks == ["test"]
assert response_id is None


@pytest.mark.skipif(
Expand Down
Loading