diff --git a/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/ChatService.java b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/ChatService.java index 19f218a4..4a866021 100644 --- a/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/ChatService.java +++ b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/ChatService.java @@ -20,11 +20,14 @@ import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.ByteArrayResource; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.MediaTypeFactory; import org.springframework.stereotype.Service; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; import reactor.core.publisher.Flux; import tools.jackson.databind.json.JsonMapper; @@ -40,7 +43,6 @@ public boolean hasFile() { @Service public class ChatService { - private static final Logger logger = LoggerFactory.getLogger(ChatService.class); private final ChatClient chatClient; @@ -127,6 +129,10 @@ public ChatService(AgentCoreMemory agentCoreMemory, @AgentCoreInvocation public Flux chat(ChatRequest request, AgentCoreContext context) { + String authorization = context.getHeader(HttpHeaders.AUTHORIZATION); + RequestContextHolder.currentRequestAttributes() + .setAttribute(HttpHeaders.AUTHORIZATION, authorization, RequestAttributes.SCOPE_REQUEST); + if (request.hasFile()) { return processDocument(request.prompt(), request.fileBase64(), request.fileName()) .collectList() diff --git a/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/OAuthMcpConfig.java b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/OAuthMcpConfig.java new file mode 100644 index 00000000..6ba05247 --- /dev/null +++ b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/OAuthMcpConfig.java @@ -0,0 +1,44 @@ +package com.example.agent; + +import io.micrometer.context.ContextRegistry; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestAttributesThreadLocalAccessor; +import org.springframework.web.context.request.RequestContextHolder; + +@Configuration +public class OAuthMcpConfig { + private static final Logger logger = LoggerFactory.getLogger(OAuthMcpConfig.class); + + static { + ContextRegistry.getInstance().registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor()); + } + + @Bean + McpSyncHttpClientRequestCustomizer oauthRequestCustomizer() { + logger.info("OAuth token injection configured"); + + return (builder, method, endpoint, body, context) -> { + String auth = getAuthFromRequestContext(); + if (auth != null) { + logger.info("Authorization header propagated to MCP calls"); + builder.setHeader(HttpHeaders.AUTHORIZATION, auth); + } + }; + } + + private String getAuthFromRequestContext() { + try { + return (String) RequestContextHolder.currentRequestAttributes() + .getAttribute(HttpHeaders.AUTHORIZATION, RequestAttributes.SCOPE_REQUEST); + } catch (IllegalStateException e) { + logger.warn("Authorization header cannot be retrieved from local context: " + e.getMessage(), e); + return null; + } + } +} diff --git a/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/SigV4McpConfig.java b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/SigV4McpConfig.java index 6b37036c..87d56ce1 100644 --- a/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/SigV4McpConfig.java +++ b/apps/java-spring-ai-agents/aiagent/src/main/java/com/example/agent/SigV4McpConfig.java @@ -15,43 +15,44 @@ import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +// Deactivated in favor to OAuthMcpConfig, because policy can evaluate only JWT principle @Configuration public class SigV4McpConfig { - private static final Logger log = LoggerFactory.getLogger(SigV4McpConfig.class); - private static final Set RESTRICTED_HEADERS = Set.of("content-length", "host", "expect"); - - @Bean - McpSyncHttpClientRequestCustomizer sigV4RequestCustomizer() { - var signer = Aws4Signer.create(); - var credentialsProvider = DefaultCredentialsProvider.create(); - var region = new DefaultAwsRegionProviderChain().getRegion(); - log.info("SigV4 MCP request customizer: region={}, service=bedrock-agentcore", region); - - return (builder, method, endpoint, body, context) -> { - byte[] bodyBytes = (body != null) ? body.getBytes(java.nio.charset.StandardCharsets.UTF_8) : null; - - var sdkRequestBuilder = SdkHttpFullRequest.builder(); - sdkRequestBuilder.uri(endpoint); - sdkRequestBuilder.method(SdkHttpMethod.valueOf(method)); - - if (bodyBytes != null && bodyBytes.length > 0) { - sdkRequestBuilder.contentStreamProvider(() -> new ByteArrayInputStream(bodyBytes)); - sdkRequestBuilder.putHeader("Content-Length", String.valueOf(bodyBytes.length)); - } - sdkRequestBuilder.putHeader("Content-Type", "application/json"); - - var signedRequest = signer.sign(sdkRequestBuilder.build(), Aws4SignerParams.builder() - .signingName("bedrock-agentcore") - .signingRegion(region) - .awsCredentials(credentialsProvider.resolveCredentials()) - .build()); - - signedRequest.headers().forEach((name, values) -> { - if (!RESTRICTED_HEADERS.contains(name.toLowerCase())) { - values.forEach(value -> builder.setHeader(name, value)); - } - }); - }; - } +// private static final Logger log = LoggerFactory.getLogger(SigV4McpConfig.class); +// private static final Set RESTRICTED_HEADERS = Set.of("content-length", "host", "expect"); +// +// @Bean +// McpSyncHttpClientRequestCustomizer sigV4RequestCustomizer() { +// var signer = Aws4Signer.create(); +// var credentialsProvider = DefaultCredentialsProvider.create(); +// var region = new DefaultAwsRegionProviderChain().getRegion(); +// log.info("SigV4 MCP request customizer: region={}, service=bedrock-agentcore", region); +// +// return (builder, method, endpoint, body, context) -> { +// byte[] bodyBytes = (body != null) ? body.getBytes(java.nio.charset.StandardCharsets.UTF_8) : null; +// +// var sdkRequestBuilder = SdkHttpFullRequest.builder(); +// sdkRequestBuilder.uri(endpoint); +// sdkRequestBuilder.method(SdkHttpMethod.valueOf(method)); +// +// if (bodyBytes != null && bodyBytes.length > 0) { +// sdkRequestBuilder.contentStreamProvider(() -> new ByteArrayInputStream(bodyBytes)); +// sdkRequestBuilder.putHeader("Content-Length", String.valueOf(bodyBytes.length)); +// } +// sdkRequestBuilder.putHeader("Content-Type", "application/json"); +// +// var signedRequest = signer.sign(sdkRequestBuilder.build(), Aws4SignerParams.builder() +// .signingName("bedrock-agentcore") +// .signingRegion(region) +// .awsCredentials(credentialsProvider.resolveCredentials()) +// .build()); +// +// signedRequest.headers().forEach((name, values) -> { +// if (!RESTRICTED_HEADERS.contains(name.toLowerCase())) { +// values.forEach(value -> builder.setHeader(name, value)); +// } +// }); +// }; +// } } diff --git a/apps/java-spring-ai-agents/aiagent/src/main/resources/application.properties b/apps/java-spring-ai-agents/aiagent/src/main/resources/application.properties index 6d3525b7..c0dc5432 100644 --- a/apps/java-spring-ai-agents/aiagent/src/main/resources/application.properties +++ b/apps/java-spring-ai-agents/aiagent/src/main/resources/application.properties @@ -20,3 +20,5 @@ agentcore.browser.screenshot-description=Take a screenshot of a web page for the # MCP Client spring.ai.mcp.client.toolcallback.enabled=true spring.ai.mcp.client.initialized=false +# Local thread variables propagation +spring.reactor.context-propagation=auto diff --git a/apps/java-spring-ai-agents/scripts/policy/01-create-policy-engine.py b/apps/java-spring-ai-agents/scripts/policy/01-create-policy-engine.py new file mode 100644 index 00000000..8f03c72a --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/01-create-policy-engine.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +""" +Step 1: Create Policy Engine + +Run: python 01-create-policy-engine.py +Auto-updates ENGINE_ID in .env +""" + +from config import update_env +from policy_commands import create_policy_engine, list_policy_engines + +ENGINE_NAME = "TravelPolicyEngine" + +if __name__ == "__main__": + print("=== Creating Policy Engine ===") + engine = create_policy_engine(ENGINE_NAME) + + # Auto-update .env + update_env("ENGINE_ID", engine['policyEngineId']) + + print("\n=== All Policy Engines ===") + list_policy_engines() diff --git a/apps/java-spring-ai-agents/scripts/policy/02-create-policy.py b/apps/java-spring-ai-agents/scripts/policy/02-create-policy.py new file mode 100644 index 00000000..009fb930 --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/02-create-policy.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +""" +Step 2: Create Policy + +Run: python 02-create-policy.py +Requires: ENGINE_ID set in config.py +""" + +from config import ENGINE_ID, GATEWAY_ARN, TARGET_NAME +from policy_commands import create_policy, list_policies, delete_all_policies +import time + +# Policy: Permit all tools for alice, EXCEPT searchFlights +POLICY = f'''permit( + principal is AgentCore::OAuthUser, + action, + resource == AgentCore::Gateway::"{GATEWAY_ARN}" +) when {{ + principal.hasTag("username") && + principal.getTag("username") == "alice" +}} unless {{ + action == AgentCore::Action::"{TARGET_NAME}___searchFlights" +}};''' + +if __name__ == "__main__": + print(f"Using ENGINE_ID: {ENGINE_ID}") + print(f"Using GATEWAY_ARN: {GATEWAY_ARN}") + + print("\n=== Deleting existing policies ===") + delete_all_policies(ENGINE_ID) + time.sleep(3) + + print("\n=== Creating policy ===") + create_policy(ENGINE_ID, "PermitAllExceptFlights", POLICY) + + time.sleep(5) + + print("\n=== Policy status ===") + list_policies(ENGINE_ID) diff --git a/apps/java-spring-ai-agents/scripts/policy/03-attach-policy-engine.py b/apps/java-spring-ai-agents/scripts/policy/03-attach-policy-engine.py new file mode 100644 index 00000000..fac1b68d --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/03-attach-policy-engine.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +""" +Step 3: Attach Policy Engine to Gateway + +Run: python 03-attach-policy-engine.py +Requires: ENGINE_ID and GATEWAY_ID set in config.py +""" + +from config import GATEWAY_ID, ENGINE_ARN +from policy_commands import attach_policy_engine, get_gateway_policy_config + +if __name__ == "__main__": + import time + + print(f"Using GATEWAY_ID: {GATEWAY_ID}") + print(f"Using ENGINE_ARN: {ENGINE_ARN}") + + print("\n=== Attaching Policy Engine ===") + attach_policy_engine(GATEWAY_ID, ENGINE_ARN, mode="ENFORCE") + + print("\n=== Waiting for attachment ===") + for i in range(10): + time.sleep(2) + config = get_gateway_policy_config(GATEWAY_ID) + if config: + print("✓ Policy engine attached") + break + print(f" Attempt {i+1}/10...") + else: + print("✗ Attachment not confirmed after 10 attempts") diff --git a/apps/java-spring-ai-agents/scripts/policy/README.md b/apps/java-spring-ai-agents/scripts/policy/README.md new file mode 100644 index 00000000..a7070886 --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/README.md @@ -0,0 +1,84 @@ +# AgentCore Policy-Based Access Control + +Demonstrates Cedar policy-based access control for MCP tools through AgentCore Gateway. + +## Overview + +``` +chat-agent (JWT) → Gateway → Policy Engine → MCP Runtime + ↓ + Cedar policy checks + user's username tag +``` + +## Key Finding: Use `unless` for Denying Tools + +AgentCore rejects standalone `forbid` policies as "Overly Restrictive". + +**Solution:** Use `permit ... unless` to deny specific tools: + +```cedar +permit( + principal is AgentCore::OAuthUser, + action, + resource == AgentCore::Gateway::"..." +) when { + principal.hasTag("username") && + principal.getTag("username") == "alice" +} unless { + action == AgentCore::Action::"travel-mcp___searchFlights" +}; +``` + +## Setup Steps + +```bash +cd /Users/shakirin/Projects/agentcore/samples/policy/scripts/policy +source .venv/bin/activate + +# 1. Create policy engine (one-time) +python 01-create-policy-engine.py + +# 2. Create policy (update ENGINE_ID in script first) +python 02-create-policy.py + +# 3. Attach to gateway (update IDs in script first) +python 03-attach-policy-engine.py +``` + +## Test Results + +| Tool | Policy | Result | +|------|--------|--------| +| `searchHotels` | Permitted | ✅ Returns hotel data | +| `searchFlights` | Denied via `unless` | ❌ Tool not available | + +## JWT Token Requirements + +The policy uses `username` tag from Cognito user tokens: + +```json +{ + "username": "alice", + "client_id": "...", + "token_use": "access" +} +``` + +Get user token: +```bash +TOKEN=$(aws cognito-idp initiate-auth \ + --client-id $CLIENT_ID \ + --auth-flow USER_PASSWORD_AUTH \ + --auth-parameters "USERNAME=alice,PASSWORD=$PASSWORD,SECRET_HASH=$SECRET_HASH" \ + --region us-east-1 \ + --query 'AuthenticationResult.AccessToken' --output text) +``` + +## Files + +- `policy.cedar` - Working Cedar policy with `unless` clause +- `policy_commands.py` - Helper functions for policy management +- `01-create-policy-engine.py` - Create policy engine +- `02-create-policy.py` - Create/update policy +- `03-attach-policy-engine.py` - Attach engine to gateway diff --git a/apps/java-spring-ai-agents/scripts/policy/config.py b/apps/java-spring-ai-agents/scripts/policy/config.py new file mode 100644 index 00000000..290e0f3b --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/config.py @@ -0,0 +1,39 @@ +import os +from dotenv import load_dotenv + +load_dotenv() + +REGION = os.getenv("REGION", "us-east-1") +ACCOUNT_ID = os.getenv("ACCOUNT_ID") +GATEWAY_ID = os.getenv("GATEWAY_ID") +ENGINE_ID = os.getenv("ENGINE_ID") +TARGET_NAME = os.getenv("TARGET_NAME", "travel-mcp") + +# Derived ARNs +GATEWAY_ARN = f"arn:aws:bedrock-agentcore:{REGION}:{ACCOUNT_ID}:gateway/{GATEWAY_ID}" +ENGINE_ARN = f"arn:aws:bedrock-agentcore:{REGION}:{ACCOUNT_ID}:policy-engine/{ENGINE_ID}" if ENGINE_ID else None + + +def update_env(key: str, value: str): + """Update a value in .env file.""" + env_path = os.path.join(os.path.dirname(__file__), ".env") + + with open(env_path, "r") as f: + lines = f.readlines() + + updated = False + for i, line in enumerate(lines): + if line.startswith(f"{key}="): + lines[i] = f"{key}={value}\n" + updated = True + break + + if not updated: + lines.append(f"{key}={value}\n") + + with open(env_path, "w") as f: + f.writelines(lines) + + # Update current process + os.environ[key] = value + print(f"Updated .env: {key}={value}") diff --git a/apps/java-spring-ai-agents/scripts/policy/deny-tools-policy.cedar b/apps/java-spring-ai-agents/scripts/policy/deny-tools-policy.cedar new file mode 100644 index 00000000..f6930c08 --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/deny-tools-policy.cedar @@ -0,0 +1,26 @@ +# Cedar policy: Permit all tools EXCEPT searchFlights for user alice +# +# Uses 'unless' clause to deny specific tools while permitting all others. +# This approach works around AgentCore's "Overly Restrictive" safety check +# that rejects standalone 'forbid' policies. +# +# NOTE: This is a template. Actual policy is generated by 02-create-policy.py +# using values from .env (GATEWAY_ID, TARGET_NAME) + +permit( + principal is AgentCore::OAuthUser, + action, + resource == AgentCore::Gateway::"arn:aws:bedrock-agentcore:${REGION}:${ACCOUNT_ID}:gateway/${GATEWAY_ID}" +) when { + principal.hasTag("username") && + principal.getTag("username") == "alice" +} unless { + action == AgentCore::Action::"${TARGET_NAME}___cancelTrip" || + action == AgentCore::Action::"${TARGET_NAME}___deleteExpense" +}; + +# To deny multiple tools, use OR: +# } unless { +# action == AgentCore::Action::"${TARGET_NAME}___searchFlights" || +# action == AgentCore::Action::"${TARGET_NAME}___anotherTool" +# }; diff --git a/apps/java-spring-ai-agents/scripts/policy/policy_commands.py b/apps/java-spring-ai-agents/scripts/policy/policy_commands.py new file mode 100644 index 00000000..097bec88 --- /dev/null +++ b/apps/java-spring-ai-agents/scripts/policy/policy_commands.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +AgentCore Policy Management Commands +""" + +from bedrock_agentcore_starter_toolkit.operations.policy.client import PolicyClient +import boto3 +from config import REGION + + +def get_policy_client(): + return PolicyClient(region_name=REGION) + + +def get_control_client(): + return boto3.client('bedrock-agentcore-control', region_name=REGION) + + +# Policy Engine operations +def create_policy_engine(name: str): + client = get_policy_client() + engine = client.create_policy_engine(name=name) + print(f"Created: {engine['policyEngineId']}") + print(f"ARN: {engine['policyEngineArn']}") + return engine + + +def list_policy_engines(): + client = get_policy_client() + for e in client.list_policy_engines().get('policyEngines', []): + print(f"{e['name']}: {e['policyEngineId']} - {e['status']}") + + +def delete_policy_engine(engine_id: str): + client = get_policy_client() + client.delete_policy_engine(engine_id) + print(f"Deleted: {engine_id}") + + +# Policy operations +def create_policy(engine_id: str, name: str, cedar_statement: str): + client = get_policy_client() + policy = client.create_policy( + policy_engine_id=engine_id, + name=name, + definition={"cedar": {"statement": cedar_statement}} + ) + print(f"Created: {policy['policyId']} - Status: {policy['status']}") + return policy + + +def list_policies(engine_id: str): + client = get_policy_client() + for p in client.list_policies(engine_id).get('policies', []): + print(f"{p['name']}: {p['status']}") + for r in p.get('statusReasons', []): + print(f" - {r}") + + +def delete_all_policies(engine_id: str): + client = get_policy_client() + for p in client.list_policies(engine_id).get('policies', []): + client.delete_policy(engine_id, p['policyId']) + print(f"Deleted: {p['name']}") + + +# Gateway attachment +def attach_policy_engine(gateway_id: str, engine_arn: str, mode: str = "ENFORCE"): + client = get_control_client() + gw = client.get_gateway(gatewayIdentifier=gateway_id) + client.update_gateway( + gatewayIdentifier=gateway_id, + name=gw['name'], + roleArn=gw['roleArn'], + protocolType=gw['protocolType'], + authorizerType=gw['authorizerType'], + policyEngineConfiguration={"arn": engine_arn, "mode": mode} + ) + print(f"Attached {engine_arn} to {gateway_id} in {mode} mode") + + +def get_gateway_policy_config(gateway_id: str): + client = get_control_client() + config = client.get_gateway(gatewayIdentifier=gateway_id).get('policyEngineConfiguration') + if config: + print(f"Policy Engine: {config['arn']}") + print(f"Mode: {config['mode']}") + else: + print("No policy engine attached") + return config