Skip to content

Commit

Permalink
feat(integrations): Add async support for ai_track decorator
Browse files Browse the repository at this point in the history
This commit adds capabilities to support async functions for the `ai_track` decorator
  • Loading branch information
czyber authored and arjenzorgdoc committed Sep 30, 2024
1 parent 66dc686 commit d0be54d
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
38 changes: 35 additions & 3 deletions sentry_sdk/ai/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from functools import wraps

import sentry_sdk.utils
Expand Down Expand Up @@ -26,8 +27,7 @@ def ai_track(description, **span_kwargs):
# type: (str, Any) -> Callable[..., Any]
def decorator(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
def wrapped(*args, **kwargs):
def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
Expand Down Expand Up @@ -56,7 +56,39 @@ def wrapped(*args, **kwargs):
_ai_pipeline_name.set(None)
return res

return wrapped
async def async_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")

with start_span(description=description, op=op, **span_kwargs) as span:
for k, v in kwargs.pop("sentry_tags", {}).items():
span.set_tag(k, v)
for k, v in kwargs.pop("sentry_data", {}).items():
span.set_data(k, v)
if curr_pipeline:
span.set_data("ai.pipeline.name", curr_pipeline)
return await f(*args, **kwargs)
else:
_ai_pipeline_name.set(description)
try:
res = await f(*args, **kwargs)
except Exception as e:
event, hint = sentry_sdk.utils.event_from_exception(
e,
client_options=sentry_sdk.get_client().options,
mechanism={"type": "ai_monitoring", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
raise e from None
finally:
_ai_pipeline_name.set(None)
return res

if inspect.iscoroutinefunction(f):
return wraps(f)(async_wrapped)
else:
return wraps(f)(sync_wrapped)

return decorator

Expand Down
62 changes: 62 additions & 0 deletions tests/test_ai_monitoring.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import sentry_sdk
from sentry_sdk.ai.monitoring import ai_track

Expand Down Expand Up @@ -57,3 +59,63 @@ def pipeline():
assert ai_pipeline_span["tags"]["user"] == "colin"
assert ai_pipeline_span["data"]["some_data"] == "value"
assert ai_run_span["description"] == "my tool"


@pytest.mark.asyncio
async def test_ai_track_async(sentry_init, capture_events):
sentry_init(traces_sample_rate=1.0)
events = capture_events()

@ai_track("my async tool")
async def async_tool(**kwargs):
pass

@ai_track("some async test pipeline")
async def async_pipeline():
await async_tool()

with sentry_sdk.start_transaction():
await async_pipeline()

transaction = events[0]
assert transaction["type"] == "transaction"
assert len(transaction["spans"]) == 2
spans = transaction["spans"]

ai_pipeline_span = spans[0] if spans[0]["op"] == "ai.pipeline" else spans[1]
ai_run_span = spans[0] if spans[0]["op"] == "ai.run" else spans[1]

assert ai_pipeline_span["description"] == "some async test pipeline"
assert ai_run_span["description"] == "my async tool"


@pytest.mark.asyncio
async def test_ai_track_async_with_tags(sentry_init, capture_events):
sentry_init(traces_sample_rate=1.0)
events = capture_events()

@ai_track("my async tool")
async def async_tool(**kwargs):
pass

@ai_track("some async test pipeline")
async def async_pipeline():
await async_tool()

with sentry_sdk.start_transaction():
await async_pipeline(
sentry_tags={"user": "czyber"}, sentry_data={"some_data": "value"}
)

transaction = events[0]
assert transaction["type"] == "transaction"
assert len(transaction["spans"]) == 2
spans = transaction["spans"]

ai_pipeline_span = spans[0] if spans[0]["op"] == "ai.pipeline" else spans[1]
ai_run_span = spans[0] if spans[0]["op"] == "ai.run" else spans[1]

assert ai_pipeline_span["description"] == "some async test pipeline"
assert ai_pipeline_span["tags"]["user"] == "czyber"
assert ai_pipeline_span["data"]["some_data"] == "value"
assert ai_run_span["description"] == "my async tool"

0 comments on commit d0be54d

Please sign in to comment.