Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/uipath_langchain/agent/guardrails/guardrail_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable

from langgraph.types import Command
from pydantic import BaseModel
from uipath.core.guardrails import (
DeterministicGuardrail,
DeterministicGuardrailsService,
Expand All @@ -26,18 +27,35 @@
get_message_content,
)
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
from uipath_langchain.agent.react.utils import extract_input_data_from_state

from ..exceptions import AgentRuntimeError, AgentRuntimeErrorCode

logger = logging.getLogger(__name__)


def _resolve_agent_input(
state: AgentGuardrailsGraphState,
input_schema: type[BaseModel] | None,
) -> dict[str, Any] | None:
if input_schema is None:
return None
try:
return extract_input_data_from_state(state, input_schema)
except Exception:
# The state may not yet carry agent-input fields (e.g., very early
# subgraphs, or schemas whose required fields aren't seeded yet); fall
# back to "no agent_input available" rather than crashing the run.
return None


def _evaluate_deterministic_guardrail(
state: AgentGuardrailsGraphState,
guardrail: DeterministicGuardrail,
execution_stage: ExecutionStage,
input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]],
output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None,
input_schema: type[BaseModel] | None = None,
):
"""Evaluate deterministic guardrail.

Expand All @@ -47,6 +65,10 @@ def _evaluate_deterministic_guardrail(
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
input_data_extractor: Function to extract input data from state.
output_data_extractor: Function to extract output data from state (optional).
input_schema: Optional input schema; when provided, the agent's
validated input parameters are extracted from state and passed to
pre-execution evaluation so rules can reference
``FieldSource.AGENT_INPUT``.

Returns:
The guardrail evaluation result.
Expand All @@ -56,7 +78,9 @@ def _evaluate_deterministic_guardrail(

if execution_stage == ExecutionStage.PRE_EXECUTION:
return service.evaluate_pre_deterministic_guardrail(
input_data=input_data, guardrail=guardrail
input_data=input_data,
guardrail=guardrail,
agent_input=_resolve_agent_input(state, input_schema),
)
else: # POST_EXECUTION
output_data = output_data_extractor(state) if output_data_extractor else {}
Expand Down Expand Up @@ -150,6 +174,7 @@ def _create_guardrail_node(
| None = None,
tool_name: str | None = None,
tool_type: str | None = None,
input_schema: type[BaseModel] | None = None,
) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
"""Private factory for guardrail evaluation nodes.

Expand Down Expand Up @@ -195,6 +220,7 @@ async def node(
execution_stage,
input_data_extractor,
output_data_extractor,
input_schema,
)
elif isinstance(guardrail, BuiltInValidatorGuardrail):
# Generate and store payload for observability
Expand Down Expand Up @@ -314,6 +340,7 @@ def create_tool_guardrail_node(
failure_node: str,
tool_name: str,
tool_type: str | None = None,
input_schema: type[BaseModel] | None = None,
) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
"""Create a guardrail node for TOOL scope guardrails.

Expand All @@ -324,6 +351,8 @@ def create_tool_guardrail_node(
failure_node: Node to route to on validation fail.
tool_name: Name of the tool to extract arguments from.
tool_type: Optional type of the tool (e.g., "process", "escalation", "mcp").
input_schema: Optional agent input schema; enables rules to reference
``FieldSource.AGENT_INPUT`` during pre-execution evaluation.

Returns:
A tuple of (node_name, node_function) for the guardrail evaluation node.
Expand Down Expand Up @@ -375,4 +404,5 @@ def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
_output_data_extractor,
tool_name,
tool_type,
input_schema,
)
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,10 @@ def create_tool_guardrails_subgraph(
scope=GuardrailScope.TOOL,
execution_stages=[ExecutionStage.PRE_EXECUTION, ExecutionStage.POST_EXECUTION],
node_factory=partial(
create_tool_guardrail_node, tool_name=tool_name, tool_type=tool_type
create_tool_guardrail_node,
tool_name=tool_name,
tool_type=tool_type,
input_schema=input_schema,
),
input_schema=input_schema,
)
106 changes: 105 additions & 1 deletion tests/agent/guardrails/test_guardrail_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ async def test_evaluate_deterministic_guardrail_pre_execution(self, monkeypatch)

assert result.result == GuardrailValidationResultType.PASSED
mock_service.evaluate_pre_deterministic_guardrail.assert_called_once_with(
input_data={"test": "data"}, guardrail=guardrail
input_data={"test": "data"}, guardrail=guardrail, agent_input=None
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -544,6 +544,110 @@ async def test_evaluate_deterministic_guardrail_post_execution(self, monkeypatch
guardrail=guardrail,
)

@pytest.mark.asyncio
async def test_evaluate_deterministic_guardrail_pre_passes_agent_input_when_schema_supplied(
self, monkeypatch
):
"""When input_schema is supplied, the agent input dict is extracted from
state and forwarded to evaluate_pre_deterministic_guardrail."""
from pydantic import BaseModel
from uipath.core.guardrails import DeterministicGuardrail

from uipath_langchain.agent.guardrails.guardrail_nodes import (
_evaluate_deterministic_guardrail,
)
from uipath_langchain.agent.react.utils import (
create_guardrails_state_with_input,
)

class AgentInputSchema(BaseModel):
user_identity: str
role: str

mock_result = GuardrailValidationResult(
result=GuardrailValidationResultType.PASSED, reason=""
)
mock_service = MagicMock()
mock_service.evaluate_pre_deterministic_guardrail.return_value = mock_result
monkeypatch.setattr(
"uipath_langchain.agent.guardrails.guardrail_nodes.DeterministicGuardrailsService",
lambda: mock_service,
)

guardrail = MagicMock(spec=DeterministicGuardrail)
state_cls = create_guardrails_state_with_input(AgentInputSchema)
state = state_cls(messages=[], user_identity="U157877", role="admin")
input_extractor = MagicMock(return_value={"target_user_id": "U157878"})
output_extractor = MagicMock()

result = _evaluate_deterministic_guardrail(
state,
guardrail,
ExecutionStage.PRE_EXECUTION,
input_extractor,
output_extractor,
AgentInputSchema,
)

assert result.result == GuardrailValidationResultType.PASSED
mock_service.evaluate_pre_deterministic_guardrail.assert_called_once_with(
input_data={"target_user_id": "U157878"},
guardrail=guardrail,
agent_input={"user_identity": "U157877", "role": "admin"},
)

@pytest.mark.asyncio
async def test_evaluate_deterministic_guardrail_post_does_not_pass_agent_input(
self, monkeypatch
):
"""Even with input_schema supplied, post-execution does NOT receive
agent_input — agent-input rules are pre-execution only by design."""
from pydantic import BaseModel
from uipath.core.guardrails import DeterministicGuardrail

from uipath_langchain.agent.guardrails.guardrail_nodes import (
_evaluate_deterministic_guardrail,
)
from uipath_langchain.agent.react.utils import (
create_guardrails_state_with_input,
)

class AgentInputSchema(BaseModel):
user_identity: str

mock_service = MagicMock()
mock_service.evaluate_post_deterministic_guardrail.return_value = (
GuardrailValidationResult(
result=GuardrailValidationResultType.PASSED, reason=""
)
)
monkeypatch.setattr(
"uipath_langchain.agent.guardrails.guardrail_nodes.DeterministicGuardrailsService",
lambda: mock_service,
)

guardrail = MagicMock(spec=DeterministicGuardrail)
state_cls = create_guardrails_state_with_input(AgentInputSchema)
state = state_cls(messages=[], user_identity="U157877")

_evaluate_deterministic_guardrail(
state,
guardrail,
ExecutionStage.POST_EXECUTION,
MagicMock(return_value={"input": "data"}),
MagicMock(return_value={"output": "data"}),
AgentInputSchema,
)

mock_service.evaluate_post_deterministic_guardrail.assert_called_once_with(
input_data={"input": "data"},
output_data={"output": "data"},
guardrail=guardrail,
)
# agent_input must NOT be present in the post call
kwargs = mock_service.evaluate_post_deterministic_guardrail.call_args.kwargs
assert "agent_input" not in kwargs

@pytest.mark.asyncio
async def test_evaluate_builtin_guardrail(self, monkeypatch):
"""Test built-in guardrail evaluation."""
Expand Down
Loading