diff --git a/agentic-rag-authorization/.env.example b/agentic-rag-authorization/.env.example deleted file mode 100644 index c2b3fe4..0000000 --- a/agentic-rag-authorization/.env.example +++ /dev/null @@ -1,17 +0,0 @@ -# Weaviate Configuration -WEAVIATE_URL=http://localhost:8080 -WEAVIATE_API_KEY= - -# SpiceDB Configuration -SPICEDB_ENDPOINT=localhost:50051 -SPICEDB_TOKEN=devtoken - -# OpenAI Configuration -# Get your API key from https://platform.openai.com/api-keys -OPENAI_API_KEY=your-api-key-here - -# Agent Behavior -MAX_RETRIEVAL_ATTEMPTS=1 - -# Logging -LOG_LEVEL=INFO diff --git a/agentic-rag-authorization/agentic_rag/authorization_helpers.py b/agentic-rag-authorization/agentic_rag/authorization_helpers.py deleted file mode 100644 index be7766c..0000000 --- a/agentic-rag-authorization/agentic_rag/authorization_helpers.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Helper functions for authorization operations.""" - -from typing import List, Tuple -from authzed.api.v1 import ( - CheckBulkPermissionsRequest, - CheckBulkPermissionsRequestItem, - ObjectReference, - SubjectReference, - Client, -) -from langchain_core.documents import Document -from .logging_config import get_logger - -logger = get_logger("authorization_helpers") - - -def batch_check_permissions( - client: Client, - subject_id: str, - documents: List[Document], -) -> Tuple[List[Document], List[str]]: - """ - Check permissions for multiple documents using SpiceDB's bulk API. - - Uses CheckBulkPermissions for efficient batch checking in a single request. - This is 5-10x faster than sequential individual permission checks. - - Args: - client: SpiceDB client - subject_id: User/subject ID to check permissions for - documents: List of documents to check permissions for - - Returns: - Tuple of (authorized_documents, denied_doc_ids) - """ - if not documents: - return [], [] - - logger.debug( - "Starting batch permission check", - extra={ - "subject_id": subject_id, - "document_count": len(documents), - }, - ) - - try: - # Build bulk request items - items = [] - for doc in documents: - doc_id = doc.metadata.get("doc_id") - items.append( - CheckBulkPermissionsRequestItem( - resource=ObjectReference(object_type="document", object_id=doc_id), - permission="view", - subject=SubjectReference( - object=ObjectReference(object_type="user", object_id=subject_id) - ), - ) - ) - - # Single bulk request to SpiceDB - request = CheckBulkPermissionsRequest(items=items) - response = client.CheckBulkPermissions(request) - - # Process results - authorized_docs = [] - denied_doc_ids = [] - - for i, pair in enumerate(response.pairs): - doc = documents[i] - doc_id = doc.metadata.get("doc_id") - - # Check if permission is granted - # permissionship: 0=UNSPECIFIED, 1=NO_PERMISSION, 2=HAS_PERMISSION - if pair.item.permissionship == 2: - authorized_docs.append(doc) - else: - denied_doc_ids.append(doc_id) - - logger.debug( - "Batch permission check complete", - extra={ - "subject_id": subject_id, - "authorized": len(authorized_docs), - "denied": len(denied_doc_ids), - }, - ) - - return authorized_docs, denied_doc_ids - - except Exception as e: - logger.error( - "Batch permission check failed", - extra={ - "subject_id": subject_id, - "error": str(e), - "error_type": type(e).__name__, - }, - exc_info=True, - ) - - # Fail closed - treat error as all denied (security-safe default) - denied_doc_ids = [doc.metadata.get("doc_id", "unknown") for doc in documents] - return [], denied_doc_ids diff --git a/agentic-rag-authorization/agentic_rag/grpc_helpers.py b/agentic-rag-authorization/agentic_rag/grpc_helpers.py index 2d73d80..5acb2e2 100644 --- a/agentic-rag-authorization/agentic_rag/grpc_helpers.py +++ b/agentic-rag-authorization/agentic_rag/grpc_helpers.py @@ -1,104 +1,33 @@ -"""Helper functions for gRPC and SpiceDB authentication.""" +"""Helper functions for SpiceDB client creation.""" -import grpc from threading import Lock from typing import Optional +from authzed.api.v1 import InsecureClient -class BearerTokenInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor): - """ - gRPC interceptor that adds bearer token to all requests. - - This is for local development with SpiceDB's --grpc-no-tls flag. - """ - - def __init__(self, token: str): - self._token = token - - def _add_authorization(self, client_call_details): - """Add authorization metadata to the call.""" - metadata = [] - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - metadata.append(("authorization", f"Bearer {self._token}")) - - return grpc._interceptor._ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - client_call_details.wait_for_ready, - client_call_details.compression, - ) - - def intercept_unary_unary(self, continuation, client_call_details, request): - """Intercept unary-unary calls.""" - new_details = self._add_authorization(client_call_details) - return continuation(new_details, request) - - def intercept_unary_stream(self, continuation, client_call_details, request): - """Intercept unary-stream calls.""" - new_details = self._add_authorization(client_call_details) - return continuation(new_details, request) - - -# Global singleton for SpiceDB client with thread-safe initialization -_spicedb_client: Optional["Client"] = None +_spicedb_client: Optional[InsecureClient] = None _spicedb_lock = Lock() -def create_insecure_spicedb_client(endpoint: str, token: str): +def create_insecure_spicedb_client(endpoint: str, token: str) -> InsecureClient: """ Create a SpiceDB client for insecure connections (local development). - This is for SpiceDB running with --grpc-no-tls flag. - - Args: - endpoint: The SpiceDB endpoint (e.g., "localhost:50051") - token: The bearer token (e.g., "devtoken") - - Returns: - authzed.api.v1.Client configured for insecure connection + For SpiceDB running with --grpc-no-tls flag. """ - from authzed.api.v1 import Client - - # Create insecure channel with bearer token interceptor - channel = grpc.insecure_channel(endpoint) - interceptor = BearerTokenInterceptor(token) - intercepted_channel = grpc.intercept_channel(channel, interceptor) - - # Create client bypassing __init__ and initialize with our channel - client = Client.__new__(Client) - client.init_stubs(intercepted_channel) + return InsecureClient(endpoint, token) - return client - -def get_spicedb_client(endpoint: str, token: str): +def get_spicedb_client(endpoint: str, token: str) -> InsecureClient: """ Get or create reusable SpiceDB client (singleton, thread-safe). - - This function provides connection pooling for SpiceDB by maintaining - a single client instance across requests, eliminating connection overhead. - - Args: - endpoint: The SpiceDB endpoint (e.g., "localhost:50051") - token: The bearer token (e.g., "devtoken") - - Returns: - authzed.api.v1.Client configured for insecure connection """ - from authzed.api.v1 import Client - global _spicedb_client - # Fast path: client already exists if _spicedb_client is not None: return _spicedb_client - # Slow path: create new client with thread-safe lock with _spicedb_lock: - # Double-check after acquiring lock if _spicedb_client is None: _spicedb_client = create_insecure_spicedb_client(endpoint, token) @@ -106,24 +35,7 @@ def get_spicedb_client(endpoint: str, token: str): def reset_spicedb_client(): - """ - Reset singleton (useful for testing). - - This allows tests to clear the cached client and create a fresh one. - """ + """Reset singleton (useful for testing).""" global _spicedb_client with _spicedb_lock: _spicedb_client = None - - -# Backward compatibility - keep the old function name -def insecure_bearer_token_credentials(token: str): - """ - Deprecated: Use create_insecure_spicedb_client instead. - - This function is kept for backward compatibility but doesn't work - with authzed Client for insecure connections. - """ - raise NotImplementedError( - "For insecure SpiceDB connections, use create_insecure_spicedb_client() instead" - ) diff --git a/agentic-rag-authorization/agentic_rag/nodes/authorization_node.py b/agentic-rag-authorization/agentic_rag/nodes/authorization_node.py index 29a33dc..1507985 100644 --- a/agentic-rag-authorization/agentic_rag/nodes/authorization_node.py +++ b/agentic-rag-authorization/agentic_rag/nodes/authorization_node.py @@ -1,65 +1,70 @@ """Authorization node - deterministic permission filtering via SpiceDB.""" from langchain_core.messages import SystemMessage +from langchain_spicedb.core import SpiceDBAuthorizer from ..state import AgenticRAGState from ..config import get_config -from ..grpc_helpers import get_spicedb_client from ..logging_config import get_logger -from ..authorization_helpers import batch_check_permissions -from ..node_helpers import log_node_execution logger = get_logger("nodes.authorization") +_authorizer: SpiceDBAuthorizer | None = None -def authorization_node(state: AgenticRAGState) -> dict: + +def _get_authorizer() -> SpiceDBAuthorizer: + global _authorizer + if _authorizer is None: + config = get_config() + _authorizer = SpiceDBAuthorizer( + spicedb_endpoint=config.spicedb_endpoint, + spicedb_token=config.spicedb_token, + resource_type="document", + subject_type="user", + permission="view", + resource_id_key="doc_id", + ) + return _authorizer + + +async def authorization_node(state: AgenticRAGState) -> dict: """ Deterministic authorization node - ALWAYS runs, cannot be bypassed. - This node filters retrieved documents based on SpiceDB permissions. + Filters retrieved documents through SpiceDB's CheckBulkPermissions API. This is a security boundary - the agent cannot bypass this check. """ - config = get_config() + authorizer = _get_authorizer() - with log_node_execution( - logger, - "authorization", - { + logger.info( + "Starting authorization", + extra={ "subject_id": state["subject_id"], "document_count": len(state["retrieved_documents"]), - } - ): - # Get or create SpiceDB client (reused across requests) - client = get_spicedb_client( - config.spicedb_endpoint, - config.spicedb_token, - ) - - # Batch check permissions using SpiceDB's bulk API - authorized_docs, denied_doc_ids = batch_check_permissions( - client, - state["subject_id"], - state["retrieved_documents"], - ) + }, + ) - denied_count = len(denied_doc_ids) + result = await authorizer.filter_documents( + documents=state["retrieved_documents"], + subject_id=state["subject_id"], + ) - logger.info( - "Authorization results", - extra={ - "authorized": len(authorized_docs), - "denied": denied_count, - "denied_doc_ids": denied_doc_ids, - }, - ) + logger.info( + "Authorization results", + extra={ + "authorized": result.total_authorized, + "denied": len(result.denied_resource_ids), + "denied_doc_ids": result.denied_resource_ids, + }, + ) - return { - "authorized_documents": authorized_docs, - "denied_count": denied_count, - "authorization_passed": len(authorized_docs) > 0, - "messages": [ - SystemMessage( - content=f"Authorization: {len(authorized_docs)}/{len(state['retrieved_documents'])} documents authorized" - ) - ], - } + return { + "authorized_documents": result.authorized_documents, + "denied_count": len(result.denied_resource_ids), + "authorization_passed": result.total_authorized > 0, + "messages": [ + SystemMessage( + content=f"Authorization: {result.total_authorized}/{result.total_retrieved} documents authorized" + ) + ], + } diff --git a/agentic-rag-authorization/requirements.txt b/agentic-rag-authorization/requirements.txt index 2eecaab..c40bc86 100644 --- a/agentic-rag-authorization/requirements.txt +++ b/agentic-rag-authorization/requirements.txt @@ -2,6 +2,7 @@ langchain>=0.1.0 langchain-openai>=0.1.0 langgraph>=0.0.20 +langchain-spicedb>=0.2.0 weaviate-client>=3.26.0,<4.0 # v3 for REST API stability (no gRPC issues) authzed>=0.7.0 python-dotenv>=1.0.0 diff --git a/agentic-rag-authorization/test_improvements.py b/agentic-rag-authorization/test_improvements.py index 751e162..9ba9dd7 100644 --- a/agentic-rag-authorization/test_improvements.py +++ b/agentic-rag-authorization/test_improvements.py @@ -86,27 +86,24 @@ def test_connection_pooling(): def test_batch_permissions(): - """Test that batch permission checker is defined.""" - print("\n=== Testing Batch Permission Checker ===") + """Test that langchain-spicedb SpiceDBAuthorizer is importable.""" + print("\n=== Testing Batch Permission Checker (langchain-spicedb) ===") try: - from agentic_rag.authorization_helpers import batch_check_permissions - - # Verify function signature + from langchain_spicedb.core import SpiceDBAuthorizer import inspect - sig = inspect.signature(batch_check_permissions) + sig = inspect.signature(SpiceDBAuthorizer.filter_documents) params = list(sig.parameters.keys()) - assert "client" in params assert "subject_id" in params assert "documents" in params - print("✅ Batch permission checker defined correctly") - print(f" Function signature: batch_check_permissions{sig}") + print("✅ SpiceDBAuthorizer.filter_documents available from langchain-spicedb") + print(f" Method signature: filter_documents{sig}") return True except (ImportError, AssertionError) as e: - print(f"❌ Batch permission checker failed: {e}") + print(f"❌ SpiceDBAuthorizer check failed: {e}") return False @@ -170,16 +167,16 @@ def test_error_handling(): assert "except Exception as e:" in retrieval_code assert "logger.error" in retrieval_code - # Check authorization helpers has try-except - with open("agentic_rag/authorization_helpers.py", "r") as f: + # Check authorization node has logging + with open("agentic_rag/nodes/authorization_node.py", "r") as f: auth_code = f.read() - assert "except Exception as e:" in auth_code - assert "logger.error" in auth_code + assert "logger.info" in auth_code + assert "SpiceDBAuthorizer" in auth_code print("✅ Error handling implemented correctly") print(" - Retrieval node has try-except") - print(" - Authorization helpers has try-except") + print(" - Authorization node uses langchain-spicedb SpiceDBAuthorizer") print(" - Errors logged with logger.error") return True except Exception as e: