diff --git a/src/uipath_langchain/agent/guardrails/guardrail_nodes.py b/src/uipath_langchain/agent/guardrails/guardrail_nodes.py index 82c0cf6de..964b061c2 100644 --- a/src/uipath_langchain/agent/guardrails/guardrail_nodes.py +++ b/src/uipath_langchain/agent/guardrails/guardrail_nodes.py @@ -4,6 +4,7 @@ from typing import Any, Callable from langgraph.types import Command +from pydantic import BaseModel from uipath.core.guardrails import ( DeterministicGuardrail, DeterministicGuardrailsService, @@ -26,18 +27,35 @@ get_message_content, ) from uipath_langchain.agent.react.types import AgentGuardrailsGraphState +from uipath_langchain.agent.react.utils import extract_input_data_from_state from ..exceptions import AgentRuntimeError, AgentRuntimeErrorCode logger = logging.getLogger(__name__) +def _resolve_agent_input( + state: AgentGuardrailsGraphState, + input_schema: type[BaseModel] | None, +) -> dict[str, Any] | None: + if input_schema is None: + return None + try: + return extract_input_data_from_state(state, input_schema) + except Exception: + # The state may not yet carry agent-input fields (e.g., very early + # subgraphs, or schemas whose required fields aren't seeded yet); fall + # back to "no agent_input available" rather than crashing the run. + return None + + def _evaluate_deterministic_guardrail( state: AgentGuardrailsGraphState, guardrail: DeterministicGuardrail, execution_stage: ExecutionStage, input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]], output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None, + input_schema: type[BaseModel] | None = None, ): """Evaluate deterministic guardrail. @@ -47,6 +65,10 @@ def _evaluate_deterministic_guardrail( execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION). input_data_extractor: Function to extract input data from state. output_data_extractor: Function to extract output data from state (optional). + input_schema: Optional input schema; when provided, the agent's + validated input parameters are extracted from state and passed to + pre-execution evaluation so rules can reference + ``FieldSource.AGENT_INPUT``. Returns: The guardrail evaluation result. @@ -56,7 +78,9 @@ def _evaluate_deterministic_guardrail( if execution_stage == ExecutionStage.PRE_EXECUTION: return service.evaluate_pre_deterministic_guardrail( - input_data=input_data, guardrail=guardrail + input_data=input_data, + guardrail=guardrail, + agent_input=_resolve_agent_input(state, input_schema), ) else: # POST_EXECUTION output_data = output_data_extractor(state) if output_data_extractor else {} @@ -150,6 +174,7 @@ def _create_guardrail_node( | None = None, tool_name: str | None = None, tool_type: str | None = None, + input_schema: type[BaseModel] | None = None, ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]: """Private factory for guardrail evaluation nodes. @@ -195,6 +220,7 @@ async def node( execution_stage, input_data_extractor, output_data_extractor, + input_schema, ) elif isinstance(guardrail, BuiltInValidatorGuardrail): # Generate and store payload for observability @@ -314,6 +340,7 @@ def create_tool_guardrail_node( failure_node: str, tool_name: str, tool_type: str | None = None, + input_schema: type[BaseModel] | None = None, ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]: """Create a guardrail node for TOOL scope guardrails. @@ -324,6 +351,8 @@ def create_tool_guardrail_node( failure_node: Node to route to on validation fail. tool_name: Name of the tool to extract arguments from. tool_type: Optional type of the tool (e.g., "process", "escalation", "mcp"). + input_schema: Optional agent input schema; enables rules to reference + ``FieldSource.AGENT_INPUT`` during pre-execution evaluation. Returns: A tuple of (node_name, node_function) for the guardrail evaluation node. @@ -375,4 +404,5 @@ def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]: _output_data_extractor, tool_name, tool_type, + input_schema, ) diff --git a/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py b/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py index edc0ad710..ef774e431 100644 --- a/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py +++ b/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py @@ -452,7 +452,10 @@ def create_tool_guardrails_subgraph( scope=GuardrailScope.TOOL, execution_stages=[ExecutionStage.PRE_EXECUTION, ExecutionStage.POST_EXECUTION], node_factory=partial( - create_tool_guardrail_node, tool_name=tool_name, tool_type=tool_type + create_tool_guardrail_node, + tool_name=tool_name, + tool_type=tool_type, + input_schema=input_schema, ), input_schema=input_schema, ) diff --git a/tests/agent/guardrails/test_guardrail_nodes.py b/tests/agent/guardrails/test_guardrail_nodes.py index 9c560f100..fa98ad129 100644 --- a/tests/agent/guardrails/test_guardrail_nodes.py +++ b/tests/agent/guardrails/test_guardrail_nodes.py @@ -498,7 +498,7 @@ async def test_evaluate_deterministic_guardrail_pre_execution(self, monkeypatch) assert result.result == GuardrailValidationResultType.PASSED mock_service.evaluate_pre_deterministic_guardrail.assert_called_once_with( - input_data={"test": "data"}, guardrail=guardrail + input_data={"test": "data"}, guardrail=guardrail, agent_input=None ) @pytest.mark.asyncio @@ -544,6 +544,110 @@ async def test_evaluate_deterministic_guardrail_post_execution(self, monkeypatch guardrail=guardrail, ) + @pytest.mark.asyncio + async def test_evaluate_deterministic_guardrail_pre_passes_agent_input_when_schema_supplied( + self, monkeypatch + ): + """When input_schema is supplied, the agent input dict is extracted from + state and forwarded to evaluate_pre_deterministic_guardrail.""" + from pydantic import BaseModel + from uipath.core.guardrails import DeterministicGuardrail + + from uipath_langchain.agent.guardrails.guardrail_nodes import ( + _evaluate_deterministic_guardrail, + ) + from uipath_langchain.agent.react.utils import ( + create_guardrails_state_with_input, + ) + + class AgentInputSchema(BaseModel): + user_identity: str + role: str + + mock_result = GuardrailValidationResult( + result=GuardrailValidationResultType.PASSED, reason="" + ) + mock_service = MagicMock() + mock_service.evaluate_pre_deterministic_guardrail.return_value = mock_result + monkeypatch.setattr( + "uipath_langchain.agent.guardrails.guardrail_nodes.DeterministicGuardrailsService", + lambda: mock_service, + ) + + guardrail = MagicMock(spec=DeterministicGuardrail) + state_cls = create_guardrails_state_with_input(AgentInputSchema) + state = state_cls(messages=[], user_identity="U157877", role="admin") + input_extractor = MagicMock(return_value={"target_user_id": "U157878"}) + output_extractor = MagicMock() + + result = _evaluate_deterministic_guardrail( + state, + guardrail, + ExecutionStage.PRE_EXECUTION, + input_extractor, + output_extractor, + AgentInputSchema, + ) + + assert result.result == GuardrailValidationResultType.PASSED + mock_service.evaluate_pre_deterministic_guardrail.assert_called_once_with( + input_data={"target_user_id": "U157878"}, + guardrail=guardrail, + agent_input={"user_identity": "U157877", "role": "admin"}, + ) + + @pytest.mark.asyncio + async def test_evaluate_deterministic_guardrail_post_does_not_pass_agent_input( + self, monkeypatch + ): + """Even with input_schema supplied, post-execution does NOT receive + agent_input — agent-input rules are pre-execution only by design.""" + from pydantic import BaseModel + from uipath.core.guardrails import DeterministicGuardrail + + from uipath_langchain.agent.guardrails.guardrail_nodes import ( + _evaluate_deterministic_guardrail, + ) + from uipath_langchain.agent.react.utils import ( + create_guardrails_state_with_input, + ) + + class AgentInputSchema(BaseModel): + user_identity: str + + mock_service = MagicMock() + mock_service.evaluate_post_deterministic_guardrail.return_value = ( + GuardrailValidationResult( + result=GuardrailValidationResultType.PASSED, reason="" + ) + ) + monkeypatch.setattr( + "uipath_langchain.agent.guardrails.guardrail_nodes.DeterministicGuardrailsService", + lambda: mock_service, + ) + + guardrail = MagicMock(spec=DeterministicGuardrail) + state_cls = create_guardrails_state_with_input(AgentInputSchema) + state = state_cls(messages=[], user_identity="U157877") + + _evaluate_deterministic_guardrail( + state, + guardrail, + ExecutionStage.POST_EXECUTION, + MagicMock(return_value={"input": "data"}), + MagicMock(return_value={"output": "data"}), + AgentInputSchema, + ) + + mock_service.evaluate_post_deterministic_guardrail.assert_called_once_with( + input_data={"input": "data"}, + output_data={"output": "data"}, + guardrail=guardrail, + ) + # agent_input must NOT be present in the post call + kwargs = mock_service.evaluate_post_deterministic_guardrail.call_args.kwargs + assert "agent_input" not in kwargs + @pytest.mark.asyncio async def test_evaluate_builtin_guardrail(self, monkeypatch): """Test built-in guardrail evaluation."""