Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.MediaTypeFactory;
import org.springframework.stereotype.Service;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import reactor.core.publisher.Flux;
import tools.jackson.databind.json.JsonMapper;

Expand All @@ -40,7 +43,6 @@ public boolean hasFile() {

@Service
public class ChatService {

private static final Logger logger = LoggerFactory.getLogger(ChatService.class);

private final ChatClient chatClient;
Expand Down Expand Up @@ -127,6 +129,10 @@ public ChatService(AgentCoreMemory agentCoreMemory,

@AgentCoreInvocation
public Flux<String> chat(ChatRequest request, AgentCoreContext context) {
String authorization = context.getHeader(HttpHeaders.AUTHORIZATION);
RequestContextHolder.currentRequestAttributes()
.setAttribute(HttpHeaders.AUTHORIZATION, authorization, RequestAttributes.SCOPE_REQUEST);

if (request.hasFile()) {
return processDocument(request.prompt(), request.fileBase64(), request.fileName())
.collectList()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.example.agent;

import io.micrometer.context.ContextRegistry;
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestAttributesThreadLocalAccessor;
import org.springframework.web.context.request.RequestContextHolder;

@Configuration
public class OAuthMcpConfig {
private static final Logger logger = LoggerFactory.getLogger(OAuthMcpConfig.class);

static {
ContextRegistry.getInstance().registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor());
}

@Bean
McpSyncHttpClientRequestCustomizer oauthRequestCustomizer() {
logger.info("OAuth token injection configured");

return (builder, method, endpoint, body, context) -> {
String auth = getAuthFromRequestContext();
if (auth != null) {
logger.info("Authorization header propagated to MCP calls");
builder.setHeader(HttpHeaders.AUTHORIZATION, auth);
}
};
}

private String getAuthFromRequestContext() {
try {
return (String) RequestContextHolder.currentRequestAttributes()
.getAttribute(HttpHeaders.AUTHORIZATION, RequestAttributes.SCOPE_REQUEST);
} catch (IllegalStateException e) {
logger.warn("Authorization header cannot be retrieved from local context: " + e.getMessage(), e);
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,44 @@
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;

// Deactivated in favor to OAuthMcpConfig, because policy can evaluate only JWT principle
@Configuration
public class SigV4McpConfig {

private static final Logger log = LoggerFactory.getLogger(SigV4McpConfig.class);
private static final Set<String> RESTRICTED_HEADERS = Set.of("content-length", "host", "expect");

@Bean
McpSyncHttpClientRequestCustomizer sigV4RequestCustomizer() {
var signer = Aws4Signer.create();
var credentialsProvider = DefaultCredentialsProvider.create();
var region = new DefaultAwsRegionProviderChain().getRegion();
log.info("SigV4 MCP request customizer: region={}, service=bedrock-agentcore", region);

return (builder, method, endpoint, body, context) -> {
byte[] bodyBytes = (body != null) ? body.getBytes(java.nio.charset.StandardCharsets.UTF_8) : null;

var sdkRequestBuilder = SdkHttpFullRequest.builder();
sdkRequestBuilder.uri(endpoint);
sdkRequestBuilder.method(SdkHttpMethod.valueOf(method));

if (bodyBytes != null && bodyBytes.length > 0) {
sdkRequestBuilder.contentStreamProvider(() -> new ByteArrayInputStream(bodyBytes));
sdkRequestBuilder.putHeader("Content-Length", String.valueOf(bodyBytes.length));
}
sdkRequestBuilder.putHeader("Content-Type", "application/json");

var signedRequest = signer.sign(sdkRequestBuilder.build(), Aws4SignerParams.builder()
.signingName("bedrock-agentcore")
.signingRegion(region)
.awsCredentials(credentialsProvider.resolveCredentials())
.build());

signedRequest.headers().forEach((name, values) -> {
if (!RESTRICTED_HEADERS.contains(name.toLowerCase())) {
values.forEach(value -> builder.setHeader(name, value));
}
});
};
}
// private static final Logger log = LoggerFactory.getLogger(SigV4McpConfig.class);
// private static final Set<String> RESTRICTED_HEADERS = Set.of("content-length", "host", "expect");
//
// @Bean
// McpSyncHttpClientRequestCustomizer sigV4RequestCustomizer() {
// var signer = Aws4Signer.create();
// var credentialsProvider = DefaultCredentialsProvider.create();
// var region = new DefaultAwsRegionProviderChain().getRegion();
// log.info("SigV4 MCP request customizer: region={}, service=bedrock-agentcore", region);
//
// return (builder, method, endpoint, body, context) -> {
// byte[] bodyBytes = (body != null) ? body.getBytes(java.nio.charset.StandardCharsets.UTF_8) : null;
//
// var sdkRequestBuilder = SdkHttpFullRequest.builder();
// sdkRequestBuilder.uri(endpoint);
// sdkRequestBuilder.method(SdkHttpMethod.valueOf(method));
//
// if (bodyBytes != null && bodyBytes.length > 0) {
// sdkRequestBuilder.contentStreamProvider(() -> new ByteArrayInputStream(bodyBytes));
// sdkRequestBuilder.putHeader("Content-Length", String.valueOf(bodyBytes.length));
// }
// sdkRequestBuilder.putHeader("Content-Type", "application/json");
//
// var signedRequest = signer.sign(sdkRequestBuilder.build(), Aws4SignerParams.builder()
// .signingName("bedrock-agentcore")
// .signingRegion(region)
// .awsCredentials(credentialsProvider.resolveCredentials())
// .build());
//
// signedRequest.headers().forEach((name, values) -> {
// if (!RESTRICTED_HEADERS.contains(name.toLowerCase())) {
// values.forEach(value -> builder.setHeader(name, value));
// }
// });
// };
// }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ agentcore.browser.screenshot-description=Take a screenshot of a web page for the
# MCP Client
spring.ai.mcp.client.toolcallback.enabled=true
spring.ai.mcp.client.initialized=false
# Local thread variables propagation
spring.reactor.context-propagation=auto
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
"""
Step 1: Create Policy Engine

Run: python 01-create-policy-engine.py
Auto-updates ENGINE_ID in .env
"""

from config import update_env
from policy_commands import create_policy_engine, list_policy_engines

ENGINE_NAME = "TravelPolicyEngine"

if __name__ == "__main__":
print("=== Creating Policy Engine ===")
engine = create_policy_engine(ENGINE_NAME)

# Auto-update .env
update_env("ENGINE_ID", engine['policyEngineId'])

print("\n=== All Policy Engines ===")
list_policy_engines()
39 changes: 39 additions & 0 deletions apps/java-spring-ai-agents/scripts/policy/02-create-policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Step 2: Create Policy

Run: python 02-create-policy.py
Requires: ENGINE_ID set in config.py
"""

from config import ENGINE_ID, GATEWAY_ARN, TARGET_NAME
from policy_commands import create_policy, list_policies, delete_all_policies
import time

# Policy: Permit all tools for alice, EXCEPT searchFlights
POLICY = f'''permit(
principal is AgentCore::OAuthUser,
action,
resource == AgentCore::Gateway::"{GATEWAY_ARN}"
) when {{
principal.hasTag("username") &&
principal.getTag("username") == "alice"
}} unless {{
action == AgentCore::Action::"{TARGET_NAME}___searchFlights"
}};'''

if __name__ == "__main__":
print(f"Using ENGINE_ID: {ENGINE_ID}")
print(f"Using GATEWAY_ARN: {GATEWAY_ARN}")

print("\n=== Deleting existing policies ===")
delete_all_policies(ENGINE_ID)
time.sleep(3)

print("\n=== Creating policy ===")
create_policy(ENGINE_ID, "PermitAllExceptFlights", POLICY)

time.sleep(5)

print("\n=== Policy status ===")
list_policies(ENGINE_ID)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""
Step 3: Attach Policy Engine to Gateway

Run: python 03-attach-policy-engine.py
Requires: ENGINE_ID and GATEWAY_ID set in config.py
"""

from config import GATEWAY_ID, ENGINE_ARN
from policy_commands import attach_policy_engine, get_gateway_policy_config

if __name__ == "__main__":
import time

print(f"Using GATEWAY_ID: {GATEWAY_ID}")
print(f"Using ENGINE_ARN: {ENGINE_ARN}")

print("\n=== Attaching Policy Engine ===")
attach_policy_engine(GATEWAY_ID, ENGINE_ARN, mode="ENFORCE")

print("\n=== Waiting for attachment ===")
for i in range(10):
time.sleep(2)
config = get_gateway_policy_config(GATEWAY_ID)
if config:
print("✓ Policy engine attached")
break
print(f" Attempt {i+1}/10...")
else:
print("✗ Attachment not confirmed after 10 attempts")
84 changes: 84 additions & 0 deletions apps/java-spring-ai-agents/scripts/policy/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# AgentCore Policy-Based Access Control

Demonstrates Cedar policy-based access control for MCP tools through AgentCore Gateway.

## Overview

```
chat-agent (JWT) → Gateway → Policy Engine → MCP Runtime
Cedar policy checks
user's username tag
```

## Key Finding: Use `unless` for Denying Tools

AgentCore rejects standalone `forbid` policies as "Overly Restrictive".

**Solution:** Use `permit ... unless` to deny specific tools:

```cedar
permit(
principal is AgentCore::OAuthUser,
action,
resource == AgentCore::Gateway::"..."
) when {
principal.hasTag("username") &&
principal.getTag("username") == "alice"
} unless {
action == AgentCore::Action::"travel-mcp___searchFlights"
};
```

## Setup Steps

```bash
cd /Users/shakirin/Projects/agentcore/samples/policy/scripts/policy
source .venv/bin/activate

# 1. Create policy engine (one-time)
python 01-create-policy-engine.py

# 2. Create policy (update ENGINE_ID in script first)
python 02-create-policy.py

# 3. Attach to gateway (update IDs in script first)
python 03-attach-policy-engine.py
```

## Test Results

| Tool | Policy | Result |
|------|--------|--------|
| `searchHotels` | Permitted | ✅ Returns hotel data |
| `searchFlights` | Denied via `unless` | ❌ Tool not available |

## JWT Token Requirements

The policy uses `username` tag from Cognito user tokens:

```json
{
"username": "alice",
"client_id": "...",
"token_use": "access"
}
```

Get user token:
```bash
TOKEN=$(aws cognito-idp initiate-auth \
--client-id $CLIENT_ID \
--auth-flow USER_PASSWORD_AUTH \
--auth-parameters "USERNAME=alice,PASSWORD=$PASSWORD,SECRET_HASH=$SECRET_HASH" \
--region us-east-1 \
--query 'AuthenticationResult.AccessToken' --output text)
```

## Files

- `policy.cedar` - Working Cedar policy with `unless` clause
- `policy_commands.py` - Helper functions for policy management
- `01-create-policy-engine.py` - Create policy engine
- `02-create-policy.py` - Create/update policy
- `03-attach-policy-engine.py` - Attach engine to gateway
39 changes: 39 additions & 0 deletions apps/java-spring-ai-agents/scripts/policy/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from dotenv import load_dotenv

load_dotenv()

REGION = os.getenv("REGION", "us-east-1")
ACCOUNT_ID = os.getenv("ACCOUNT_ID")
GATEWAY_ID = os.getenv("GATEWAY_ID")
ENGINE_ID = os.getenv("ENGINE_ID")
TARGET_NAME = os.getenv("TARGET_NAME", "travel-mcp")

# Derived ARNs
GATEWAY_ARN = f"arn:aws:bedrock-agentcore:{REGION}:{ACCOUNT_ID}:gateway/{GATEWAY_ID}"
ENGINE_ARN = f"arn:aws:bedrock-agentcore:{REGION}:{ACCOUNT_ID}:policy-engine/{ENGINE_ID}" if ENGINE_ID else None


def update_env(key: str, value: str):
"""Update a value in .env file."""
env_path = os.path.join(os.path.dirname(__file__), ".env")

with open(env_path, "r") as f:
lines = f.readlines()

updated = False
for i, line in enumerate(lines):
if line.startswith(f"{key}="):
lines[i] = f"{key}={value}\n"
updated = True
break

if not updated:
lines.append(f"{key}={value}\n")

with open(env_path, "w") as f:
f.writelines(lines)

# Update current process
os.environ[key] = value
print(f"Updated .env: {key}={value}")
Loading