Skip to content

Commit 53859e4

Browse files
.
1 parent db5dbb9 commit 53859e4

1 file changed

Lines changed: 117 additions & 97 deletions

File tree

sentry_sdk/integrations/anthropic.py

Lines changed: 117 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import json
3+
import inspect
34
from collections.abc import Iterable
45
from functools import wraps
56
from typing import TYPE_CHECKING
@@ -37,6 +38,7 @@
3738
except ImportError:
3839
Omit = None
3940

41+
from anthropic import Stream, AsyncStream
4042
from anthropic.resources import AsyncMessages, Messages
4143

4244
if TYPE_CHECKING:
@@ -45,11 +47,10 @@
4547
raise DidNotEnable("Anthropic not installed")
4648

4749
if TYPE_CHECKING:
48-
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
50+
from typing import Any, AsyncIterator, Iterator, List, Optional, Union, Callable
4951
from sentry_sdk.tracing import Span
5052
from sentry_sdk._types import TextPart
5153

52-
from anthropic import AsyncStream
5354
from anthropic.types import RawMessageStreamEvent
5455

5556

@@ -75,6 +76,117 @@ def setup_once() -> None:
7576
Messages.create = _wrap_message_create(Messages.create)
7677
AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
7778

79+
Stream.__iter__ = _wrap_stream_iter(Stream.__iter__)
80+
AsyncStream.__aiter__ = _wrap_async_stream_aiter(AsyncStream.__aiter__)
81+
82+
83+
def _wrap_stream_iter(
84+
f: "Callable[..., Iterator[RawMessageStreamEvent]]",
85+
) -> "Callable[..., Iterator[RawMessageStreamEvent]]":
86+
@wraps(f)
87+
def _patched_iter(self: "Stream") -> "Iterator[RawMessageStreamEvent]":
88+
if not hasattr(self, "_sentry_span"):
89+
for event in f(self):
90+
yield event
91+
92+
model = None
93+
usage = _RecordedUsage()
94+
content_blocks: "list[str]" = []
95+
96+
for event in f(self):
97+
(
98+
model,
99+
usage,
100+
content_blocks,
101+
) = _collect_ai_data(
102+
event,
103+
model,
104+
usage,
105+
content_blocks,
106+
)
107+
yield event
108+
109+
# Anthropic's input_tokens excludes cached/cache_write tokens.
110+
# Normalize to total input tokens for correct cost calculations.
111+
total_input = (
112+
usage.input_tokens
113+
+ (usage.cache_read_input_tokens or 0)
114+
+ (usage.cache_write_input_tokens or 0)
115+
)
116+
117+
span = self._sentry_span
118+
integration = self._integration
119+
120+
_set_output_data(
121+
span=span,
122+
integration=integration,
123+
model=model,
124+
input_tokens=total_input,
125+
output_tokens=usage.output_tokens,
126+
cache_read_input_tokens=usage.cache_read_input_tokens,
127+
cache_write_input_tokens=usage.cache_write_input_tokens,
128+
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
129+
finish_span=True,
130+
)
131+
132+
return f(self)
133+
134+
return _patched_iter
135+
136+
137+
def _wrap_async_stream_aiter(
138+
f: "Callable[..., AsyncIterator[RawMessageStreamEvent]]",
139+
) -> "Callable[..., AsyncIterator[RawMessageStreamEvent]]":
140+
@wraps(f)
141+
async def _patched_aiter(
142+
self: "AsyncStream",
143+
) -> "AsyncIterator[RawMessageStreamEvent]":
144+
if not hasattr(self, "_sentry_span"):
145+
async for event in f(self):
146+
yield event
147+
148+
model = None
149+
usage = _RecordedUsage()
150+
content_blocks: "list[str]" = []
151+
152+
async for event in f(self):
153+
(
154+
model,
155+
usage,
156+
content_blocks,
157+
) = _collect_ai_data(
158+
event,
159+
model,
160+
usage,
161+
content_blocks,
162+
)
163+
yield event
164+
165+
# Anthropic's input_tokens excludes cached/cache_write tokens.
166+
# Normalize to total input tokens for correct cost calculations.
167+
total_input = (
168+
usage.input_tokens
169+
+ (usage.cache_read_input_tokens or 0)
170+
+ (usage.cache_write_input_tokens or 0)
171+
)
172+
173+
span = self._sentry_span
174+
integration = self._integration
175+
176+
_set_output_data(
177+
span=span,
178+
integration=integration,
179+
model=model,
180+
input_tokens=total_input,
181+
output_tokens=usage.output_tokens,
182+
cache_read_input_tokens=usage.cache_read_input_tokens,
183+
cache_write_input_tokens=usage.cache_write_input_tokens,
184+
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
185+
finish_span=True,
186+
)
187+
188+
return _patched_aiter
189+
78190

79191
def _capture_exception(exc: "Any") -> None:
80192
set_span_errored()
@@ -392,98 +504,6 @@ def _set_output_data(
392504
span.__exit__(None, None, None)
393505

394506

395-
def _patch_streaming_response_iterator(
396-
result: "AsyncStream[RawMessageStreamEvent]",
397-
span: "sentry_sdk.tracing.Span",
398-
integration: "AnthropicIntegration",
399-
) -> None:
400-
"""
401-
Responsible for closing the `gen_ai.chat` span and setting attributes acquired during response consumption.
402-
"""
403-
old_iterator = result._iterator
404-
405-
def new_iterator() -> "Iterator[MessageStreamEvent]":
406-
model = None
407-
usage = _RecordedUsage()
408-
content_blocks: "list[str]" = []
409-
410-
for event in old_iterator:
411-
(
412-
model,
413-
usage,
414-
content_blocks,
415-
) = _collect_ai_data(
416-
event,
417-
model,
418-
usage,
419-
content_blocks,
420-
)
421-
yield event
422-
423-
# Anthropic's input_tokens excludes cached/cache_write tokens.
424-
# Normalize to total input tokens for correct cost calculations.
425-
total_input = (
426-
usage.input_tokens
427-
+ (usage.cache_read_input_tokens or 0)
428-
+ (usage.cache_write_input_tokens or 0)
429-
)
430-
431-
_set_output_data(
432-
span=span,
433-
integration=integration,
434-
model=model,
435-
input_tokens=total_input,
436-
output_tokens=usage.output_tokens,
437-
cache_read_input_tokens=usage.cache_read_input_tokens,
438-
cache_write_input_tokens=usage.cache_write_input_tokens,
439-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
440-
finish_span=True,
441-
)
442-
443-
async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
444-
model = None
445-
usage = _RecordedUsage()
446-
content_blocks: "list[str]" = []
447-
448-
async for event in old_iterator:
449-
(
450-
model,
451-
usage,
452-
content_blocks,
453-
) = _collect_ai_data(
454-
event,
455-
model,
456-
usage,
457-
content_blocks,
458-
)
459-
yield event
460-
461-
# Anthropic's input_tokens excludes cached/cache_write tokens.
462-
# Normalize to total input tokens for correct cost calculations.
463-
total_input = (
464-
usage.input_tokens
465-
+ (usage.cache_read_input_tokens or 0)
466-
+ (usage.cache_write_input_tokens or 0)
467-
)
468-
469-
_set_output_data(
470-
span=span,
471-
integration=integration,
472-
model=model,
473-
input_tokens=total_input,
474-
output_tokens=usage.output_tokens,
475-
cache_read_input_tokens=usage.cache_read_input_tokens,
476-
cache_write_input_tokens=usage.cache_write_input_tokens,
477-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
478-
finish_span=True,
479-
)
480-
481-
if str(type(result._iterator)) == "<class 'async_generator'>":
482-
result._iterator = new_iterator_async()
483-
else:
484-
result._iterator = new_iterator()
485-
486-
487507
def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
488508
integration = kwargs.pop("integration")
489509
if integration is None:
@@ -510,9 +530,9 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
510530

511531
result = yield f, args, kwargs
512532

513-
is_streaming_response = kwargs.get("stream", False)
514-
if is_streaming_response:
515-
_patch_streaming_response_iterator(result, span, integration)
533+
if isinstance(result, Stream) or isinstance(result, AsyncStream):
534+
result._sentry_span = span
535+
result._integration = integration
516536
return result
517537

518538
with capture_internal_exceptions():

0 commit comments

Comments (0)