diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 1349f6ae0..79502e635 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -18,6 +18,7 @@ import pydantic from opentelemetry import context, trace from opentelemetry.util._decorator import _AgnosticContextManager +from typing_extensions import Protocol from langfuse import propagate_attributes from langfuse._client.attributes import LangfuseOtelSpanAttributes @@ -142,12 +143,26 @@ def keys(self) -> List[str]: return list(self._contexts.keys()) +class LangchainGenerationMetadataExtractor(Protocol): + def __call__( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Optional[Dict[str, Any]]: ... + + class LangchainCallbackHandler(LangchainBaseCallbackHandler): def __init__( self, *, public_key: Optional[str] = None, trace_context: Optional[TraceContext] = None, + generation_metadata_extractor: Optional[ + LangchainGenerationMetadataExtractor + ] = None, ) -> None: """Initialize the LangchainCallbackHandler. @@ -157,6 +172,9 @@ def __init__( setting a custom trace id for the root LangChain run. Pass a `TraceContext` dict, e.g. `{"trace_id": ""}` (and optionally `{"parent_span_id": ""}`) to link the trace to an upstream system. + generation_metadata_extractor: Optional callable that receives the LangChain `LLMResult`, + `run_id`, `parent_run_id`, and callback kwargs, and returns metadata to merge with + the ended Langfuse generation observation. Example: Use a custom trace id without context managers: @@ -183,6 +201,8 @@ def __init__( self._prompt_to_parent_run_map: Dict[UUID, Any] = {} self._updated_completion_start_time_memo: Set[UUID] = set() self._trace_context = trace_context + self._generation_metadata_extractor = generation_metadata_extractor + self._generation_metadata_by_run_id: Dict[UUID, Dict[str, Any]] = {} self._pending_resume_trace_contexts = _PendingResumeTraceContextStore( MAX_PENDING_RESUME_TRACE_CONTEXTS ) @@ -191,6 +211,25 @@ def __init__( self.last_trace_id: Optional[str] = None + def get_generation_metadata( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Optional[Dict[str, Any]]: + """Return response-derived metadata for a LangChain generation observation.""" + if self._generation_metadata_extractor is None: + return None + + return self._generation_metadata_extractor( + response, + run_id=run_id, + parent_run_id=parent_run_id, + **kwargs, + ) + def on_llm_new_token( self, token: str, @@ -1191,18 +1230,18 @@ def __on_llm_action( current_parent_run_id ) + observation_metadata = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + # If llm is run isolated and outside chain, keep trace attributes + keep_langfuse_trace_attributes=True if parent_run_id is None else False, + ) + content = { "name": self.get_langchain_run_name(serialized, **kwargs), "input": prompts, - "metadata": self._get_langchain_observation_metadata( - parent_run_id=parent_run_id, - tags=tags, - metadata=metadata, - # If llm is run isolated and outside chain, keep trace attributes - keep_langfuse_trace_attributes=True - if parent_run_id is None - else False, - ), + "metadata": observation_metadata, "model": model_name, "model_parameters": self._parse_model_parameters(kwargs), "prompt": registered_prompt, @@ -1220,6 +1259,8 @@ def __on_llm_action( as_type="generation", **content ) # type: ignore self._attach_observation(run_id, generation) + if observation_metadata is not None: + self._generation_metadata_by_run_id[run_id] = observation_metadata self.last_trace_id = self._runs[run_id].trace_id @@ -1314,14 +1355,30 @@ def on_llm_end( model = _parse_model(response) generation = self._detach_observation(run_id) + initial_metadata = self._generation_metadata_by_run_id.pop(run_id, {}) if generation is not None: + try: + generation_metadata = self.get_generation_metadata( + response, + run_id=run_id, + parent_run_id=parent_run_id, + **kwargs, + ) + except Exception as e: + langfuse_logger.exception(e) + generation_metadata = None + generation.update( output=extracted_response, usage=llm_usage, usage_details=llm_usage, input=kwargs.get("inputs"), model=model, + metadata={ + **initial_metadata, + **(generation_metadata or {}), + }, ).end() except Exception as e: @@ -1346,6 +1403,7 @@ def on_llm_error( self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) generation = self._detach_observation(run_id) + self._generation_metadata_by_run_id.pop(run_id, None) if generation is not None: level, status_message = self._get_error_level_and_status_message(error) @@ -1380,10 +1438,12 @@ def _reset(self, root_run_id: UUID) -> None: root_run_state = self._root_run_states.pop(run_state.root_run_id, None) if root_run_state is None: self._run_states.pop(root_run_id, None) + self._generation_metadata_by_run_id.pop(root_run_id, None) return for run_id in root_run_state.run_ids: self._run_states.pop(run_id, None) + self._generation_metadata_by_run_id.pop(run_id, None) def _exit_propagation_context(self, run_id: UUID) -> None: root_run_state = self._get_root_run_state(run_id) diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py index 27298342c..84b69d6ee 100644 --- a/tests/unit/test_langchain.py +++ b/tests/unit/test_langchain.py @@ -97,6 +97,150 @@ def test_chat_model_callback_exports_generation_span( } +def test_chat_model_callback_adds_response_derived_generation_metadata( + langfuse_memory_client, get_span +): + response = ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content="bonjour", + response_metadata={ + "headers": { + "x-request-id": "req_123", + }, + }, + ), + text="bonjour", + ) + ], + llm_output={"model_name": "gpt-4o-mini"}, + ) + + def generation_metadata_extractor( + response, *, run_id, parent_run_id=None, **kwargs + ): + generation = response.generations[-1][-1] + headers = generation.message.response_metadata.get("headers", {}) + + return { + "provider_request_id": headers.get("x-request-id"), + "run_id": str(run_id), + "run_has_value": run_id is not None, + } + + with patch.object(ChatOpenAI, "_generate", return_value=response): + handler = CallbackHandler( + generation_metadata_extractor=generation_metadata_extractor + ) + + with langfuse_memory_client.start_as_current_observation(name="parent"): + ChatOpenAI(api_key="test", temperature=0).invoke( + [HumanMessage(content="hello")], + config={ + "callbacks": [handler], + "metadata": {"initial_metadata": "kept"}, + }, + ) + + langfuse_memory_client.flush() + generation_span = get_span("ChatOpenAI") + + assert ( + generation_span.attributes[ + f"{LangfuseOtelSpanAttributes.OBSERVATION_METADATA}.initial_metadata" + ] + == "kept" + ) + assert ( + generation_span.attributes[ + f"{LangfuseOtelSpanAttributes.OBSERVATION_METADATA}.provider_request_id" + ] + == "req_123" + ) + assert ( + generation_span.attributes[ + f"{LangfuseOtelSpanAttributes.OBSERVATION_METADATA}.run_has_value" + ] + is True + ) + + +def test_chat_model_callback_supports_generation_metadata_subclass_hook( + langfuse_memory_client, get_span +): + class CustomCallbackHandler(CallbackHandler): + def get_generation_metadata( + self, response, *, run_id, parent_run_id=None, **kwargs + ): + return {"provider_request_id": "subclass_req_123"} + + response = ChatResult( + generations=[ + ChatGeneration(message=AIMessage(content="bonjour"), text="bonjour") + ], + llm_output={"model_name": "gpt-4o-mini"}, + ) + + with patch.object(ChatOpenAI, "_generate", return_value=response): + handler = CustomCallbackHandler() + + with langfuse_memory_client.start_as_current_observation(name="parent"): + ChatOpenAI(api_key="test", temperature=0).invoke( + [HumanMessage(content="hello")], + config={"callbacks": [handler]}, + ) + + langfuse_memory_client.flush() + generation_span = get_span("ChatOpenAI") + + assert ( + generation_span.attributes[ + f"{LangfuseOtelSpanAttributes.OBSERVATION_METADATA}.provider_request_id" + ] + == "subclass_req_123" + ) + + +def test_chat_model_callback_ends_generation_when_metadata_hook_fails( + langfuse_memory_client, get_span, json_attr +): + response = ChatResult( + generations=[ + ChatGeneration(message=AIMessage(content="bonjour"), text="bonjour") + ], + llm_output={"model_name": "gpt-4o-mini"}, + ) + + def generation_metadata_extractor(response, *, run_id, parent_run_id=None, **kwargs): + raise RuntimeError("metadata unavailable") + + with patch.object(ChatOpenAI, "_generate", return_value=response): + handler = CallbackHandler( + generation_metadata_extractor=generation_metadata_extractor + ) + + with langfuse_memory_client.start_as_current_observation(name="parent"): + ChatOpenAI(api_key="test", temperature=0).invoke( + [HumanMessage(content="hello")], + config={"callbacks": [handler]}, + ) + + langfuse_memory_client.flush() + generation_span = get_span("ChatOpenAI") + + assert ( + generation_span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_TYPE] + == "generation" + ) + assert json_attr( + generation_span, LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT + ) == { + "role": "assistant", + "content": "bonjour", + } + + def test_llm_callback_exports_generation_span(langfuse_memory_client, get_span): response = LLMResult( generations=[[Generation(text="sockzilla")]],