diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 72252a1d..4b50b9b9 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -62,6 +62,7 @@ def is_equal(expected, output): from .audit import * from .auto import auto_instrument as auto_instrument +from .dataset_pipeline import * from .framework import * from .framework2 import * from .functions.invoke import * diff --git a/py/src/braintrust/dataset_pipeline.py b/py/src/braintrust/dataset_pipeline.py new file mode 100644 index 00000000..4cbbbf4f --- /dev/null +++ b/py/src/braintrust/dataset_pipeline.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass +from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar + +from typing_extensions import NotRequired, TypedDict + +from .generated_types import ObjectReference +from .logger import Metadata +from .trace import Trace + + +DatasetPipelineScope: TypeAlias = Literal["span", "trace"] + + +class DatasetPipelineSource(TypedDict, total=False): + project_id: str + project_name: str + org_name: str + filter: str + scope: DatasetPipelineScope + + +@dataclass(frozen=True) +class PipelineSource: + filter: str | None = None + scope: DatasetPipelineScope | None = None + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + + def as_dict(self) -> DatasetPipelineSource: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "filter": self.filter, + "scope": self.scope, + } + ) + + +class DatasetPipelineTarget(TypedDict): + dataset_name: str + project_id: NotRequired[str] + project_name: NotRequired[str] + org_name: NotRequired[str] + description: NotRequired[str] + metadata: NotRequired[Metadata] + + +@dataclass(frozen=True) +class PipelineTarget: + dataset_name: str + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + description: str | None = None + metadata: Metadata | None = None + + def as_dict(self) -> DatasetPipelineTarget: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "dataset_name": self.dataset_name, + "description": self.description, + "metadata": self.metadata, + } + ) + + +class DatasetPipelineRow(TypedDict, total=False): + id: str + input: Any | None + expected: Any | None + tags: Sequence[str] | None + metadata: Metadata | None + origin: ObjectReference + + +Row = TypeVar("Row", bound=DatasetPipelineRow) + + +class DatasetPipelineTransformArgs(TypedDict, total=False): + input: Any | None + output: Any | None + metadata: Metadata | None + expected: Any | None + trace: Trace + + +DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None +DatasetPipelineSourceLike: TypeAlias = DatasetPipelineSource | PipelineSource +DatasetPipelineTargetLike: TypeAlias = DatasetPipelineTarget | PipelineTarget + + +def _drop_none(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} + + +def _normalize_source(source: DatasetPipelineSourceLike) -> DatasetPipelineSource: + if isinstance(source, PipelineSource): + return source.as_dict() + return dict(source) + + +def _normalize_target(target: DatasetPipelineTargetLike) -> DatasetPipelineTarget: + if isinstance(target, PipelineTarget): + return target.as_dict() + return dict(target) + + +class DatasetPipelineTransform(Protocol[Row]): + def __call__( + self, + input: Any | None = None, + output: Any | None = None, + metadata: Metadata | None = None, + expected: Any | None = None, + trace: Trace | None = None, + ) -> DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]]: ... + + +@dataclass(frozen=True) +class DatasetPipelineDefinition(Generic[Row]): + source: DatasetPipelineSource + transform: DatasetPipelineTransform[Row] + target: DatasetPipelineTarget + name: str | None = None + + +_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = [] + + +def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]: + return list(_DATASET_PIPELINES) + + +def is_dataset_pipeline_definition(value: object) -> bool: + return isinstance(value, DatasetPipelineDefinition) + + +def DatasetPipeline( + name: str | None = None, + *, + source: DatasetPipelineSourceLike, + transform: DatasetPipelineTransform[DatasetPipelineRow], + target: DatasetPipelineTargetLike, +) -> DatasetPipelineDefinition[DatasetPipelineRow]: + definition = DatasetPipelineDefinition( + name=name, + source=_normalize_source(source), + transform=transform, + target=_normalize_target(target), + ) + _DATASET_PIPELINES.append(definition) + return definition diff --git a/py/src/braintrust/test_trace.py b/py/src/braintrust/test_trace.py index 577d4e58..ebfea43c 100644 --- a/py/src/braintrust/test_trace.py +++ b/py/src/braintrust/test_trace.py @@ -41,6 +41,35 @@ async def fetch_fn(span_type): assert len(result) == 3 assert {s.span_id for s in result} == {"span-1", "span-2", "span-3"} + @pytest.mark.asyncio + async def test_fetch_preserves_span_result_fields(self): + """Test that fetched spans preserve fields needed for full trace attachments.""" + mock_spans = [ + make_span( + "span-1", + "tool", + expected={"answer": "ok"}, + error={"message": "boom"}, + metrics={"start": 1, "end": 2}, + scores={"quality": 0}, + tags=["debug"], + ) + ] + + async def fetch_fn(span_type): + del span_type + return mock_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + result = await fetcher.get_spans() + + assert result[0].expected == {"answer": "ok"} + assert result[0].error == {"message": "boom"} + assert result[0].metrics == {"start": 1, "end": 2} + assert result[0].scores == {"quality": 0} + assert result[0].tags == ["debug"] + assert result[0].to_dict()["error"] == {"message": "boom"} + @pytest.mark.asyncio async def test_fetch_specific_span_types(self): """Test fetching specific span types when filter specified.""" diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py index 24bcefa2..c254fbda 100644 --- a/py/src/braintrust/trace.py +++ b/py/src/braintrust/trace.py @@ -22,17 +22,27 @@ def __init__( input: Any | None = None, output: Any | None = None, metadata: Metadata | None = None, + expected: Any | None = None, + error: Any | None = None, + scores: Any | None = None, + metrics: Any | None = None, span_id: str | None = None, span_parents: list[str] | None = None, span_attributes: dict[str, Any] | None = None, + tags: list[str] | None = None, **kwargs: Any, ): self.input = input self.output = output self.metadata = metadata + self.expected = expected + self.error = error + self.scores = scores + self.metrics = metrics self.span_id = span_id self.span_parents = span_parents self.span_attributes = span_attributes + self.tags = tags # Store any additional fields for key, value in kwargs.items(): setattr(self, key, value) @@ -64,9 +74,10 @@ def __init__( root_span_id: str, state: BraintrustState, span_type_filter: list[str] | None = None, + include_scorers: bool = False, ): # Build the filter expression for root_span_id and optionally span_attributes.type - filter_expr = self._build_filter(root_span_id, span_type_filter) + filter_expr = self._build_filter(root_span_id, span_type_filter, include_scorers) super().__init__( object_type=object_type, @@ -76,7 +87,11 @@ def __init__( self._state = state @staticmethod - def _build_filter(root_span_id: str, span_type_filter: list[str] | None = None) -> dict[str, Any]: + def _build_filter( + root_span_id: str, + span_type_filter: list[str] | None = None, + include_scorers: bool = False, + ) -> dict[str, Any]: """Build BTQL filter expression.""" children = [ # Base filter: root_span_id = 'value' @@ -85,23 +100,32 @@ def _build_filter(root_span_id: str, span_type_filter: list[str] | None = None) "left": {"op": "ident", "name": ["root_span_id"]}, "right": {"op": "literal", "value": root_span_id}, }, - # Exclude span_attributes.purpose = 'scorer' - { - "op": "or", - "children": [ - { - "op": "isnull", - "expr": {"op": "ident", "name": ["span_attributes", "purpose"]}, - }, - { - "op": "ne", - "left": {"op": "ident", "name": ["span_attributes", "purpose"]}, - "right": {"op": "literal", "value": "scorer"}, - }, - ], - }, ] + if not include_scorers: + children.append( + { + "op": "or", + "children": [ + { + "op": "isnull", + "expr": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + }, + { + "op": "ne", + "left": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + "right": {"op": "literal", "value": "scorer"}, + }, + ], + } + ) + # If span type filter specified, add it if span_type_filter and len(span_type_filter) > 0: children.append( @@ -123,6 +147,7 @@ def _get_state(self) -> BraintrustState: SpanFetchFn = Callable[[list[str] | None], Awaitable[list[SpanData]]] +SpanFetchWithOptionsFn = Callable[[list[str] | None, bool], Awaitable[list[SpanData]]] class GetThreadOptions(TypedDict, total=False): @@ -152,7 +177,14 @@ def __init__( if fetch_fn is not None: # Direct fetch function injection (for testing) - self._fetch_fn = fetch_fn + async def _fetch_fn( + span_type: list[str] | None, + include_scorers: bool = False, + ) -> list[SpanData]: + del include_scorers + return await fetch_fn(span_type) + + self._fetch_fn: SpanFetchWithOptionsFn = _fetch_fn else: # Standard constructor with SpanFetcher if object_type is None or object_id is None or root_span_id is None or get_state is None: @@ -160,7 +192,10 @@ def __init__( "Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state" ) - async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: + async def _fetch_fn( + span_type: list[str] | None, + include_scorers: bool = False, + ) -> list[SpanData]: state = await get_state() fetcher = SpanFetcher( object_type=object_type, @@ -168,21 +203,17 @@ async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: root_span_id=root_span_id, state=state, span_type_filter=span_type, + include_scorers=include_scorers, ) rows = list(fetcher.fetch()) - # Filter out scorer spans - filtered = [ - row - for row in rows - if not ( - isinstance(row.get("span_attributes"), dict) - and row.get("span_attributes", {}).get("purpose") == "scorer" - ) - ] return [ SpanData( input=row.get("input"), output=row.get("output"), + expected=row.get("expected"), + error=row.get("error"), + scores=row.get("scores"), + metrics=row.get("metrics"), metadata=row.get("metadata"), span_id=row.get("span_id"), span_parents=row.get("span_parents"), @@ -191,22 +222,34 @@ async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: _xact_id=row.get("_xact_id"), _pagination_key=row.get("_pagination_key"), root_span_id=row.get("root_span_id"), + is_root=row.get("is_root"), + created=row.get("created"), + tags=row.get("tags"), ) - for row in filtered + for row in rows ] self._fetch_fn = _fetch_fn - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Get spans, using cache when possible. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans """ + if include_scorers: + return await self._fetch_fn(span_type, True) + # If we've fetched all spans, just filter from cache if self._all_fetched: return self._get_from_cache(span_type) @@ -231,7 +274,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: async def _fetch_spans(self, span_type: list[str] | None) -> None: """Fetch spans from the server.""" - spans = await self._fetch_fn(span_type) + spans = await self._fetch_fn(span_type, False) for span in spans: span_attrs = span.span_attributes or {} @@ -267,12 +310,18 @@ def get_configuration(self) -> dict[str, str]: """Get the trace configuration (object_type, object_id, root_span_id).""" ... - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all spans for this root span. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -352,7 +401,12 @@ def get_configuration(self) -> dict[str, str]: "root_span_id": self._root_span_id, } - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all rows for this root span from its parent object (experiment or project logs). First checks the local span cache for recently logged spans, then falls @@ -360,6 +414,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -368,7 +423,11 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id) if cached_spans and len(cached_spans) > 0: # Filter by purpose - spans = [span for span in cached_spans if not (span.span_attributes or {}).get("purpose") == "scorer"] + spans = [ + span + for span in cached_spans + if include_scorers or not (span.span_attributes or {}).get("purpose") == "scorer" + ] # Filter by span type if requested if span_type and len(span_type) > 0: @@ -379,16 +438,21 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: SpanData( input=span.input, output=span.output, + expected=getattr(span, "expected", None), + error=getattr(span, "error", None), + scores=getattr(span, "scores", None), + metrics=getattr(span, "metrics", None), metadata=span.metadata, span_id=span.span_id, span_parents=span.span_parents, span_attributes=span.span_attributes, + tags=getattr(span, "tags", None), ) for span in spans ] # Fall back to CachedSpanFetcher for BTQL fetching with caching - return await self._cached_fetcher.get_spans(span_type) + return await self._cached_fetcher.get_spans(span_type, include_scorers=include_scorers) async def get_thread(self, options: GetThreadOptions | None = None) -> list[Any]: """