From 1044d172e64fbbd15cf0af90a3086de9b18c2d34 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 30 Jan 2026 16:47:07 +0100 Subject: [PATCH 01/15] refactor + add appsec data to span --- .../trace/lambda/LambdaAppSecHandler.java | 427 +++++------------- 1 file changed, 121 insertions(+), 306 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 32bfa7fad6e..9417ba213d1 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -4,8 +4,6 @@ import com.squareup.moshi.JsonAdapter; import com.squareup.moshi.Moshi; -import datadog.logging.RatelimitedLogger; -import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; import datadog.trace.api.gateway.Flow; @@ -23,12 +21,12 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; -import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -36,27 +34,19 @@ import org.slf4j.LoggerFactory; /** - * Handles AppSec processing for AWS Lambda invocations. Extracts Lambda event data and invokes - * AppSec gateway callbacks. + * Handles AppSec processing for AWS Lambda invocations. + * Extracts Lambda event data and invokes AppSec gateway callbacks. */ public class LambdaAppSecHandler { private static final Logger log = LoggerFactory.getLogger(LambdaAppSecHandler.class); - private static final RatelimitedLogger rlLog = new RatelimitedLogger(log, 5, TimeUnit.MINUTES); - - private static final Moshi MOSHI = new Moshi.Builder().build(); - private static final JsonAdapter MAP_ADAPTER = MOSHI.adapter(Map.class); - private static final JsonAdapter OBJECT_ADAPTER = MOSHI.adapter(Object.class); - - private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); /** - * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes - * all relevant AppSec gateway callbacks. + * Process AppSec request data at the start of a Lambda invocation. + * Extract event data and invokes all relevant AppSec gateway callbacks. * * @param event the Lambda event object - * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing - * fails + * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing fails */ public static AgentSpanContext processRequestStart(Object event) { if (!ActiveSubsystems.APPSEC_ACTIVE) { @@ -65,20 +55,15 @@ public static AgentSpanContext processRequestStart(Object event) { } if (!(event instanceof ByteArrayInputStream)) { - log.debug( - "Event is not a ByteArrayInputStream, type: {}", - event != null ? event.getClass().getName() : "null"); + log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); return null; } try { LambdaEventData eventData = extractEventData((ByteArrayInputStream) event); - if (eventData == LambdaEventData.EMPTY) { - return null; - } return processAppSecRequestData(eventData); } catch (Exception e) { - log.debug("Failed to process AppSec request data", e); + log.error("Failed to process AppSec request data", e); return null; } } @@ -97,11 +82,12 @@ public static void processRequestEnd(AgentSpan span) { if (requestContext != null) { AgentTracer.TracerAPI tracer = AgentTracer.get(); BiFunction> requestEndedCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestEnded()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestEnded()); if (requestEndedCallback != null) { requestEndedCallback.apply(requestContext, span); } else { - log.debug("requestEnded callback is null"); + log.warn("requestEnded callback is null"); } } } @@ -125,16 +111,20 @@ public static AgentSpanContext mergeContexts( if (appSecContext instanceof TagContext) { TagContext extracted = (TagContext) appSecContext; Object appSecData = extracted.getRequestContextDataAppSec(); + Object iastData = extracted.getRequestContextDataIast(); if (extensionContext instanceof TagContext) { TagContext merged = (TagContext) extensionContext; if (appSecData != null) { merged.withRequestContextDataAppSec(appSecData); } + if (iastData != null) { + merged.withRequestContextDataIast(iastData); + } return merged; } - rlLog.warn( + log.warn( "Cannot merge AppSec data: extension context is not a TagContext: {}", extensionContext.getClass()); } @@ -146,7 +136,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa Supplier> requestStartedCallback = tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestStarted()); if (requestStartedCallback == null) { - log.debug("requestStarted callback is null"); + log.warn("requestStarted callback is null"); return null; } @@ -162,48 +152,38 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestMethodUriRaw if (eventData.method != null && eventData.path != null) { - datadog.trace.api.function.TriFunction> - methodUriCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestMethodUriRaw()); + datadog.trace.api.function.TriFunction> methodUriCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); if (methodUriCallback != null) { - // Reconstruct full path with query string for AppSec analysis - String fullPath = buildFullPath(eventData.path, eventData.queryParameters); - LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(fullPath, eventData.headers); + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); methodUriCallback.apply(requestContext, eventData.method, uriAdapter); } else { - log.debug("requestMethodUriRaw callback is null"); + log.warn("requestMethodUriRaw callback is null"); } } // Call requestHeader for each header if (eventData.headers != null && !eventData.headers.isEmpty()) { TriConsumer headerCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestHeader()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); if (headerCallback != null) { for (Map.Entry header : eventData.headers.entrySet()) { headerCallback.accept(requestContext, header.getKey(), header.getValue()); } } else { - log.debug("requestHeader callback is null"); + log.warn("requestHeader callback is null"); } } // Call requestClientSocketAddress if (eventData.sourceIp != null) { - datadog.trace.api.function.TriFunction> - socketAddrCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestClientSocketAddress()); + datadog.trace.api.function.TriFunction> socketAddrCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); if (socketAddrCallback != null) { Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; socketAddrCallback.apply(requestContext, eventData.sourceIp, port); } else { - log.debug("requestClientSocketAddress callback is null"); + log.warn("requestClientSocketAddress callback is null"); } } @@ -215,32 +195,28 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa if (headerDoneCallback != null) { headerDoneCallback.apply(requestContext); } else { - log.debug("requestHeaderDone callback is null"); + log.warn("requestHeaderDone callback is null"); } // Call requestPathParams if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { BiFunction, Flow> pathParamsCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestPathParams()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); if (pathParamsCallback != null) { pathParamsCallback.apply(requestContext, eventData.pathParameters); } else { - log.debug("requestPathParams callback is null"); + log.warn("requestPathParams callback is null"); } } // Call requestBodyProcessed if (eventData.body != null) { BiFunction> bodyCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestBodyProcessed()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); if (bodyCallback != null) { bodyCallback.apply(requestContext, eventData.body); } else { - log.debug("requestBodyProcessed callback is null"); + log.warn("requestBodyProcessed callback is null"); } } } @@ -249,24 +225,16 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) throws IOException { - inputStream.mark(0); try { - int availableBytes = inputStream.available(); - - if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { - log.debug( - "Event size {} exceeds limit {} or is invalid, skipping AppSec processing", - availableBytes, - MAX_EVENT_SIZE); - return LambdaEventData.EMPTY; - } - - byte[] bytes = new byte[availableBytes]; - int read = inputStream.read(bytes); - if (read <= 0) { - return LambdaEventData.EMPTY; + StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); + try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { + char[] buffer = new char[1024]; + int charsRead; + while ((charsRead = reader.read(buffer)) != -1) { + jsonBuilder.append(buffer, 0, charsRead); + } } - return extractEventDataFromJson(new String(bytes, 0, read, StandardCharsets.UTF_8)); + return extractEventDataFromJson(jsonBuilder.toString()); } finally { inputStream.reset(); } @@ -275,11 +243,14 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream private static LambdaEventData extractEventDataFromJson(String json) { try { // Parse JSON into a Map - Map event = MAP_ADAPTER.fromJson(json); + JsonAdapter adapter = + new Moshi.Builder().build().adapter(Map.class); + + Map event = adapter.fromJson(json); log.debug("Event JSON parsed successfully"); if (event == null) { - return LambdaEventData.EMPTY; + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); } // Detect trigger type @@ -303,12 +274,12 @@ private static LambdaEventData extractEventDataFromJson(String json) { return extractGenericData(event); } } catch (Exception e) { - log.debug("Failed to parse event data from JSON", e); - return LambdaEventData.EMPTY; + log.error("Failed to parse event data from JSON", e); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); } } - static LambdaTriggerType detectTriggerType(Map event) { + private static LambdaTriggerType detectTriggerType(Map event) { Object requestContextObj = event.get("requestContext"); if (requestContextObj instanceof Map) { @@ -324,8 +295,8 @@ static LambdaTriggerType detectTriggerType(Map event) { } // Check for WebSocket - if (requestContext.containsKey("connectionId") - && (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { + if (requestContext.containsKey("connectionId") && + (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; } @@ -353,12 +324,12 @@ static LambdaTriggerType detectTriggerType(Map event) { return LambdaTriggerType.UNKNOWN; } - /** Extracts data from API Gateway v1 (REST API) event */ + /** + * Extracts data from API Gateway v1 (REST API) event + */ private static LambdaEventData extractApiGatewayV1Data(Map event) { Map headers = extractHeaders(event.get("headers")); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = - extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -372,25 +343,15 @@ private static LambdaEventData extractApiGatewayV1Data(Map event sourceIp = (String) identity.get("sourceIp"); } - return new LambdaEventData( - headers, - method, - path, - sourceIp, - null, - LambdaTriggerType.API_GATEWAY_V1_REST, - pathParameters, - queryParameters, - body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); } - /** Extracts data from API Gateway v2 (HTTP API) or Lambda URL event */ - private static LambdaEventData extractApiGatewayV2HttpData( - Map event, LambdaTriggerType triggerType) { + /** + * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event + */ + private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = - extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -407,24 +368,15 @@ private static LambdaEventData extractApiGatewayV2HttpData( sourcePort = ((Number) portObj).intValue(); } - return new LambdaEventData( - headers, - method, - path, - sourceIp, - sourcePort, - triggerType, - pathParameters, - queryParameters, - body); + return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); } - /** Extracts data from API Gateway v2 WebSocket event */ + /** + * Extracts data from API Gateway v2 WebSocket event + */ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = - extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -440,21 +392,13 @@ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event, LambdaTriggerType triggerType) { + /** + * Extracts data from ALB event (with or without multi-value headers) + */ + private static LambdaEventData extractAlbData(Map event, LambdaTriggerType triggerType) { Map headers; if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { @@ -469,10 +413,9 @@ private static LambdaEventData extractAlbData( if (entry.getValue() instanceof java.util.List) { java.util.List values = (java.util.List) entry.getValue(); // Join multiple values with comma - String joinedValue = - values.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining(", ")); + String joinedValue = values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); headers.put(key, joinedValue); } else { headers.put(key, String.valueOf(entry.getValue())); @@ -485,37 +428,21 @@ private static LambdaEventData extractAlbData( } Map pathParameters = extractPathParameters(event.get("pathParameters")); - - // ALB can have both queryStringParameters and multiValueQueryStringParameters - Map> queryParameters; - if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { - queryParameters = - extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); - } else { - queryParameters = extractQueryParameters(event.get("queryStringParameters")); - } - Object body = extractBody(event); String method = (String) event.get("httpMethod"); String path = (String) event.get("path"); - String xff = headers.get("x-forwarded-for"); - String sourceIp = null; - if (xff != null) { - int commaIdx = xff.indexOf(','); - sourceIp = (commaIdx >= 0 ? xff.substring(0, commaIdx) : xff).trim(); - } + String sourceIp = headers.get("x-forwarded-for"); - return new LambdaEventData( - headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); } - /** Generic data extraction for unknown trigger types (fallback) */ + /** + * Generic data extraction for unknown trigger types (fallback) + */ private static LambdaEventData extractGenericData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = - extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); String method = null; @@ -561,21 +488,12 @@ private static LambdaEventData extractGenericData(Map event) { } } - return new LambdaEventData( - headers, - method, - path, - sourceIp, - null, - LambdaTriggerType.UNKNOWN, - pathParameters, - queryParameters, - body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); } /** - * Generic helper method to extract string key-value pairs from an object. Converts all keys and - * values to strings, filtering out null entries. + * Generic helper method to extract string key-value pairs from an object. + * Converts all keys and values to strings, filtering out null entries. */ private static Map extractStringMap(Object mapObj) { Map result = new java.util.HashMap<>(); @@ -592,7 +510,9 @@ private static Map extractStringMap(Object mapObj) { return result; } - /** Helper method to extract headers from event */ + /** + * Helper method to extract headers from event + */ private static Map extractHeaders(Object headersObj) { Map headers = extractStringMap(headersObj); log.debug("Extracted {} headers", headers.size()); @@ -602,7 +522,9 @@ private static Map extractHeaders(Object headersObj) { return headers; } - /** Helper method to extract path parameters from event */ + /** + * Helper method to extract path parameters from event + */ private static Map extractPathParameters(Object pathParamsObj) { Map pathParams = extractStringMap(pathParamsObj); log.debug("Extracted {} path parameters", pathParams.size()); @@ -610,88 +532,8 @@ private static Map extractPathParameters(Object pathParamsObj) { } /** - * Helper method to extract query parameters from event. Converts Map to - * Map> format expected by AppSec. - */ - private static Map> extractQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); - if (queryParamsObj instanceof Map) { - Map rawMap = (Map) queryParamsObj; - for (Map.Entry entry : rawMap.entrySet()) { - if (entry.getKey() != null && entry.getValue() != null) { - String key = String.valueOf(entry.getKey()); - String value = String.valueOf(entry.getValue()); - result.put(key, java.util.Collections.singletonList(value)); - } - } - } - log.debug("Extracted {} query parameters", result.size()); - return result; - } - - /** - * Helper method to extract multi-value query parameters (used by ALB). Handles Map> format directly. - */ - private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); - if (queryParamsObj instanceof Map) { - Map rawMap = (Map) queryParamsObj; - for (Map.Entry entry : rawMap.entrySet()) { - if (entry.getKey() != null && entry.getValue() != null) { - String key = String.valueOf(entry.getKey()); - if (entry.getValue() instanceof java.util.List) { - java.util.List values = (java.util.List) entry.getValue(); - java.util.List stringValues = new java.util.ArrayList<>(); - for (Object value : values) { - if (value != null) { - stringValues.add(String.valueOf(value)); - } - } - result.put(key, stringValues); - } else { - result.put(key, java.util.Collections.singletonList(String.valueOf(entry.getValue()))); - } - } - } - } - log.debug("Extracted {} multi-value query parameters", result.size()); - return result; - } - - /** - * Helper method to build full path including query string. Lambda events provide path and query - * parameters separately, so we need to reconstruct the full URI for AppSec to parse. - */ - private static String buildFullPath(String path, Map> queryParameters) { - if (queryParameters == null || queryParameters.isEmpty()) { - return path; - } - - StringBuilder fullPath = new StringBuilder(path); - fullPath.append('?'); - - boolean first = true; - for (Map.Entry> entry : queryParameters.entrySet()) { - String key = entry.getKey(); - for (String value : entry.getValue()) { - if (!first) { - fullPath.append('&'); - } - first = false; - fullPath.append(key); - if (value != null) { - fullPath.append('=').append(value); - } - } - } - - return fullPath.toString(); - } - - /** - * Helper method to extract and merge headers with cookies array from event. API Gateway v2 - * provides a separate 'cookies' array that should be merged with headers. + * Helper method to extract and merge headers with cookies array from event. + * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. */ private static Map extractHeadersWithCookies(Map event) { Map headers = extractHeaders(event.get("headers")); @@ -702,10 +544,9 @@ private static Map extractHeadersWithCookies(Map java.util.List cookiesList = (java.util.List) cookiesObj; if (!cookiesList.isEmpty()) { // Join cookies with "; " separator per RFC 6265 - String cookieValue = - cookiesList.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining("; ")); + String cookieValue = cookiesList.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining("; ")); // Merge with existing cookie header if present String existingCookie = headers.get("cookie"); @@ -720,7 +561,9 @@ private static Map extractHeadersWithCookies(Map return headers; } - /** Helper method to extract and parse body from event */ + /** + * Helper method to extract and parse body from event + */ private static Object extractBody(Map event) { Object bodyObj = event.get("body"); if (bodyObj == null) { @@ -752,21 +595,26 @@ private static Object extractBody(Map event) { return bodyString; } - /** Helper method to parse body as JSON */ + /** + * Helper method to parse body as JSON + */ private static Object parseBodyAsJson(String body) { if (body == null || body.isEmpty() || "null".equals(body)) { return null; } try { - return OBJECT_ADAPTER.fromJson(body); + JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); + Object parsed = adapter.fromJson(body); + return parsed; } catch (Exception e) { return null; } } /** - * Temporary RequestContext implementation to hold AppSecRequestContext before a span is created. + * Temporary RequestContext implementation to hold AppSecRequestContext + * before a span is created. */ private static class TemporaryRequestContext implements RequestContext { private final Object appSecRequestContext; @@ -819,19 +667,23 @@ public void close() { } } - /** Enum representing different AWS Lambda trigger types */ - enum LambdaTriggerType { - API_GATEWAY_V1_REST, // API Gateway REST API (v1) - API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) + /** + * Enum representing different AWS Lambda trigger types + */ + private enum LambdaTriggerType { + API_GATEWAY_V1_REST, // API Gateway REST API (v1) + API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket - ALB, // Application Load Balancer - ALB_MULTI_VALUE, // ALB with multi-value headers - LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger + ALB, // Application Load Balancer + ALB_MULTI_VALUE, // ALB with multi-value headers + LAMBDA_URL, // Lambda Function URL + UNKNOWN // Unknown or unsupported trigger } - /** Object for Lambda event data needed for AppSec processing */ - static class LambdaEventData { + /** + * Object for Lambda event data needed for AppSec processing + */ + private static class LambdaEventData { final Map headers; final String method; final String path; @@ -839,31 +691,9 @@ static class LambdaEventData { final Integer sourcePort; final LambdaTriggerType triggerType; final Map pathParameters; - final Map> queryParameters; final Object body; - static final LambdaEventData EMPTY = - new LambdaEventData( - Collections.emptyMap(), - null, - null, - null, - null, - LambdaTriggerType.UNKNOWN, - Collections.emptyMap(), - Collections.emptyMap(), - null); - - LambdaEventData( - Map headers, - String method, - String path, - String sourceIp, - Integer sourcePort, - LambdaTriggerType triggerType, - Map pathParameters, - Map> queryParameters, - Object body) { + LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { this.headers = headers; this.method = method; this.path = path; @@ -871,19 +701,18 @@ static class LambdaEventData { this.sourcePort = sourcePort; this.triggerType = triggerType; this.pathParameters = pathParameters; - this.queryParameters = queryParameters; this.body = body; } } - /** URIDataAdapter implementation for Lambda events. */ + /** + * URIDataAdapter implementation for Lambda events. + */ private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; private final String query; - private final String scheme; - private final int port; - LambdaURIDataAdapter(String pathWithQuery, Map headers) { + LambdaURIDataAdapter(String pathWithQuery) { if (pathWithQuery != null) { int queryIndex = pathWithQuery.indexOf('?'); if (queryIndex != -1) { @@ -897,25 +726,11 @@ private static class LambdaURIDataAdapter extends URIDataAdapterBase { this.path = "/"; this.query = null; } - - String forwardedProto = headers != null ? headers.get("x-forwarded-proto") : null; - this.scheme = - (forwardedProto != null && !forwardedProto.isEmpty()) ? forwardedProto : "https"; - - String forwardedPort = headers != null ? headers.get("x-forwarded-port") : null; - int parsedPort = -1; - if (forwardedPort != null && !forwardedPort.isEmpty()) { - try { - parsedPort = Integer.parseInt(forwardedPort.trim()); - } catch (NumberFormatException ignored) { - } - } - this.port = parsedPort > 0 ? parsedPort : 443; } @Override public String scheme() { - return scheme; + return "https"; } @Override @@ -925,7 +740,7 @@ public String host() { @Override public int port() { - return port; + return 443; } @Override From 1c6d3a3f87054c97930cc3643d147f20cc05b65f Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Thu, 5 Feb 2026 15:12:59 +0100 Subject: [PATCH 02/15] unit tests --- .../trace/lambda/LambdaAppSecHandler.java | 37 +- .../lambda/LambdaAppSecHandlerTest.groovy | 1339 +++++++++++++++++ 2 files changed, 1362 insertions(+), 14 deletions(-) create mode 100644 dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 9417ba213d1..7d05c43ad7b 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -4,6 +4,7 @@ import com.squareup.moshi.JsonAdapter; import com.squareup.moshi.Moshi; +import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; import datadog.trace.api.gateway.Flow; @@ -41,6 +42,12 @@ public class LambdaAppSecHandler { private static final Logger log = LoggerFactory.getLogger(LambdaAppSecHandler.class); + private static final Moshi MOSHI = new Moshi.Builder().build(); + private static final JsonAdapter MAP_ADAPTER = MOSHI.adapter(Map.class); + private static final JsonAdapter OBJECT_ADAPTER = MOSHI.adapter(Object.class); + + private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); + /** * Process AppSec request data at the start of a Lambda invocation. * Extract event data and invokes all relevant AppSec gateway callbacks. @@ -111,16 +118,12 @@ public static AgentSpanContext mergeContexts( if (appSecContext instanceof TagContext) { TagContext extracted = (TagContext) appSecContext; Object appSecData = extracted.getRequestContextDataAppSec(); - Object iastData = extracted.getRequestContextDataIast(); if (extensionContext instanceof TagContext) { TagContext merged = (TagContext) extensionContext; if (appSecData != null) { merged.withRequestContextDataAppSec(appSecData); } - if (iastData != null) { - merged.withRequestContextDataIast(iastData); - } return merged; } @@ -225,8 +228,18 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) throws IOException { + inputStream.mark(0); try { - StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); + int availableBytes = inputStream.available(); + + if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { + log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", + availableBytes, MAX_EVENT_SIZE); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, + LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + + StringBuilder jsonBuilder = new StringBuilder(availableBytes); try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { char[] buffer = new char[1024]; int charsRead; @@ -243,10 +256,7 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream private static LambdaEventData extractEventDataFromJson(String json) { try { // Parse JSON into a Map - JsonAdapter adapter = - new Moshi.Builder().build().adapter(Map.class); - - Map event = adapter.fromJson(json); + Map event = MAP_ADAPTER.fromJson(json); log.debug("Event JSON parsed successfully"); if (event == null) { @@ -279,7 +289,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } - private static LambdaTriggerType detectTriggerType(Map event) { + static LambdaTriggerType detectTriggerType(Map event) { Object requestContextObj = event.get("requestContext"); if (requestContextObj instanceof Map) { @@ -604,8 +614,7 @@ private static Object parseBodyAsJson(String body) { } try { - JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); - Object parsed = adapter.fromJson(body); + Object parsed = OBJECT_ADAPTER.fromJson(body); return parsed; } catch (Exception e) { return null; @@ -670,7 +679,7 @@ public void close() { /** * Enum representing different AWS Lambda trigger types */ - private enum LambdaTriggerType { + enum LambdaTriggerType { API_GATEWAY_V1_REST, // API Gateway REST API (v1) API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket @@ -683,7 +692,7 @@ private enum LambdaTriggerType { /** * Object for Lambda event data needed for AppSec processing */ - private static class LambdaEventData { + static class LambdaEventData { final Map headers; final String method; final String path; diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy new file mode 100644 index 00000000000..eea29fe68c8 --- /dev/null +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -0,0 +1,1339 @@ +package datadog.trace.lambda + +import datadog.trace.api.Config +import datadog.trace.api.function.TriConsumer +import datadog.trace.api.gateway.CallbackProvider +import datadog.trace.api.gateway.Flow +import datadog.trace.api.gateway.RequestContext +import datadog.trace.api.gateway.RequestContextSlot +import datadog.trace.bootstrap.ActiveSubsystems +import datadog.trace.bootstrap.instrumentation.api.AgentSpan +import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.bootstrap.instrumentation.api.TagContext +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter +import datadog.trace.core.test.DDCoreSpecification +import spock.lang.Shared + +import java.nio.charset.StandardCharsets +import java.util.function.BiFunction +import java.util.function.Function +import java.util.function.Supplier + +import static datadog.trace.api.gateway.Events.EVENTS + +class LambdaAppSecHandlerTest extends DDCoreSpecification { + + @Shared + def originalAppSecActive + + def setupSpec() { + originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE + } + + def cleanupSpec() { + ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive + } + + def setup() { + ActiveSubsystems.APPSEC_ACTIVE = true + } + + def "processRequestStart returns null when AppSec is disabled"() { + given: + ActiveSubsystems.APPSEC_ACTIVE = false + def event = createInputStream('{"test": "data"}') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for non-ByteArrayInputStream"() { + when: + def result = LambdaAppSecHandler.processRequestStart("not a stream") + + then: + result == null + } + + def "processRequestStart returns null for null event"() { + when: + def result = LambdaAppSecHandler.processRequestStart(null) + + then: + result == null + } + + def "processRequestStart returns null for oversized event"() { + given: + def maxSize = Config.get().getAppSecBodyParsingSizeLimit() + def largeBody = "x" * (maxSize + 1) + def event = createInputStream(largeBody) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for zero-size event"() { + given: + def event = createInputStream('') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for malformed JSON"() { + given: + def event = createInputStream('{invalid json') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "stream can be read multiple times after processing"() { + given: + def jsonData = '{"test": "data", "requestContext": {"httpMethod": "GET"}}' + def event = createInputStream(jsonData) + + when: + LambdaAppSecHandler.processRequestStart(event) + event.reset() + def content = new String(event.readAllBytes(), StandardCharsets.UTF_8) + + then: + content == jsonData + } + + + // ============================================================================ + // Trigger Type Detection Tests + // ============================================================================ + + def "detects API Gateway v1 REST trigger type"() { + given: + def event = [ + requestContext: [ + httpMethod: "GET", + requestId: "abc123" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST + } + + def "detects API Gateway v2 HTTP trigger type"() { + given: + def event = [ + requestContext: [ + http: [ + method: "POST", + path: "/api" + ], + domainName: "api.example.com" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP + } + + def "detects Lambda Function URL trigger type"() { + given: + def event = [ + requestContext: [ + http: [ + method: "GET", + path: "/" + ], + domainName: "xyz123.lambda-url.us-east-1.on.aws" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL + } + + def "detects ALB trigger type without multi-value headers"() { + given: + def event = [ + httpMethod: "GET", + path: "/", + requestContext: [ + elb: [ + targetGroupArn: "arn:aws:..." + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB + } + + def "detects ALB trigger type with multi-value headers"() { + given: + def event = [ + httpMethod: "GET", + path: "/", + multiValueHeaders: [ + accept: ["text/html", "application/json"] + ], + requestContext: [ + elb: [ + targetGroupArn: "arn:aws:..." + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE + } + + def "detects WebSocket trigger type with routeKey"() { + given: + def event = [ + requestContext: [ + connectionId: "conn-123", + routeKey: "\$connect" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET + } + + def "detects WebSocket trigger type with eventType"() { + given: + def event = [ + requestContext: [ + connectionId: "conn-456", + eventType: "CONNECT" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET + } + + def "detects unknown trigger type for unrecognized events"() { + given: + def event = [ + someUnknownField: "value" + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN + } + + def "detects unknown trigger type for empty requestContext"() { + given: + def event = [ + requestContext: [:] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN + } + + def "detects Lambda URL when http present but no domainName"() { + given: + def event = [ + requestContext: [ + http: [ + method: "GET", + path: "/ambiguous" + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL + } + + // ============================================================================ + // Data Extraction Tests with Mocked Callbacks + // ============================================================================ + + def "extracts API Gateway v1 REST data correctly"() { + given: + def eventJson = ''' + { + "path": "/api/users/123", + "httpMethod": "POST", + "headers": { + "Content-Type": "application/json", + "Authorization": "Bearer token123" + }, + "pathParameters": { + "userId": "123" + }, + "body": "{\\"name\\": \\"John\\"}", + "requestContext": { + "httpMethod": "POST", + "requestId": "req-123", + "identity": { + "sourceIp": "192.168.1.100" + } + } + } + ''' + def event = createInputStream(eventJson) + + // Track callback invocations + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + def capturedSourcePort = null + def capturedPathParams = null + def capturedBody = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + }, + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + result instanceof TagContext + + capturedMethod == "POST" + capturedPath == "/api/users/123" + capturedHeaders["Content-Type"] == "application/json" + capturedHeaders["Authorization"] == "Bearer token123" + capturedSourceIp == "192.168.1.100" + capturedSourcePort == 0 + capturedPathParams == ["userId": "123"] + capturedBody instanceof Map + capturedBody.name == "John" + } + + def "extracts API Gateway v2 HTTP data correctly"() { + given: + def eventJson = ''' + { + "version": "2.0", + "headers": { + "content-type": "application/json", + "x-custom-header": "custom-value" + }, + "cookies": ["session=abc123", "user=john"], + "pathParameters": { + "id": "456" + }, + "body": "test body", + "requestContext": { + "http": { + "method": "PUT", + "path": "/api/items/456", + "sourceIp": "10.0.0.50", + "sourcePort": 54321 + }, + "domainName": "api.example.com" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + def capturedSourcePort = null + def capturedPathParams = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "PUT" + capturedPath == "/api/items/456" + capturedHeaders["content-type"] == "application/json" + capturedHeaders["x-custom-header"] == "custom-value" + capturedHeaders["cookie"] == "session=abc123; user=john" + capturedSourceIp == "10.0.0.50" + capturedSourcePort == 54321 + capturedPathParams == ["id": "456"] + } + + def "extracts Lambda Function URL data correctly"() { + given: + def eventJson = ''' + { + "version": "2.0", + "headers": { + "host": "xyz.lambda-url.us-east-1.on.aws" + }, + "requestContext": { + "http": { + "method": "GET", + "path": "/function/path", + "sourceIp": "1.2.3.4" + }, + "domainName": "xyz.lambda-url.us-east-1.on.aws" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "GET" + capturedPath == "/function/path" + } + + def "extracts ALB data correctly"() { + given: + def eventJson = ''' + { + "path": "/alb/test", + "httpMethod": "DELETE", + "headers": { + "x-forwarded-for": "203.0.113.42", + "user-agent": "curl/7.64.1" + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "DELETE" + capturedPath == "/alb/test" + capturedSourceIp == "203.0.113.42" + } + + def "extracts ALB multi-value headers correctly"() { + given: + def eventJson = ''' + { + "path": "/test", + "httpMethod": "GET", + "multiValueHeaders": { + "accept": ["text/html", "application/json"], + "x-custom": ["value1", "value2"] + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:..." + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["accept"] == "text/html, application/json" + capturedHeaders["x-custom"] == "value1, value2" + } + + def "handles multi-value headers with empty list"() { + given: + def eventJson = ''' + { + "path": "/test", + "httpMethod": "GET", + "multiValueHeaders": { + "accept": [], + "x-custom": ["value1"] + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:..." + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["accept"] == "" // Empty list should result in empty string + capturedHeaders["x-custom"] == "value1" + } + + def "extracts WebSocket data correctly"() { + given: + def eventJson = ''' + { + "requestContext": { + "routeKey": "$connect", + "connectionId": "conn-abc123", + "identity": { + "sourceIp": "192.168.0.100" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "WEBSOCKET" + capturedPath == "\$connect" + capturedSourceIp == "192.168.0.100" + } + + def "handles base64 encoded body correctly"() { + given: + def originalBody = "This is test data" + def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()) + def eventJson = """ + { + "body": "${base64Body}", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + """ + def event = createInputStream(eventJson) + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == originalBody + } + + def "handles null body correctly"() { + given: + def event = createInputStream('{"body": null, "requestContext": {"httpMethod": "GET"}}') + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "NOT_CALLED" // Callback should not be invoked for null body + } + + def "handles empty body correctly"() { + given: + def event = createInputStream('{"body": "", "requestContext": {"httpMethod": "POST"}}') + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "" // Empty body is passed as empty string to WAF + } + + def "handles path with query string correctly"() { + given: + def eventJson = ''' + { + "path": "/api/users?id=123&filter=active", + "requestContext": { + "httpMethod": "GET" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedPath = null + def capturedQuery = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedPath = uri.path() + capturedQuery = uri.query() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedPath == "/api/users" + capturedQuery == "id=123&filter=active" + } + + def "handles invalid base64 body gracefully"() { + given: + def eventJson = ''' + { + "body": "not-valid-base64", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "NOT_CALLED" // Should not call body callback when decode fails + } + + def "handles base64 decoded empty string body"() { + given: + def base64Empty = Base64.getEncoder().encodeToString("".getBytes()) + def eventJson = """ + { + "body": "${base64Empty}", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + """ + def event = createInputStream(eventJson) + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "" // Should pass empty string after decoding + } + + def "handles body with special characters"() { + given: + def eventJson = ''' + { + "body": "{\\"text\\": \\"Hello δΈ–η•Œ 🌍\\"}", + "requestContext": { + "httpMethod": "POST" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody instanceof Map + capturedBody.text == "Hello δΈ–η•Œ 🌍" + } + + // ============================================================================ + // Generic Data Extraction Tests + // ============================================================================ + + def "extracts data from unknown trigger type using generic extraction"() { + given: + def eventJson = ''' + { + "path": "/generic/path", + "httpMethod": "PATCH", + "headers": { + "x-custom-header": "generic-value" + }, + "unknownField": "should be ignored", + "requestContext": { + "identity": { + "sourceIp": "203.0.113.1" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "PATCH" + capturedPath == "/generic/path" + capturedHeaders["x-custom-header"] == "generic-value" + capturedSourceIp == "203.0.113.1" + } + + def "extracts data from unknown trigger with http in requestContext"() { + given: + def eventJson = ''' + { + "requestContext": { + "http": { + "method": "OPTIONS", + "path": "/options/path", + "sourceIp": "198.51.100.50" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "OPTIONS" + capturedPath == "/options/path" + capturedSourceIp == "198.51.100.50" + } + + def "handles cookies merging with existing cookie header"() { + given: + def eventJson = ''' + { + "headers": { + "cookie": "existing=value" + }, + "cookies": ["new=cookie1", "another=cookie2"], + "requestContext": { + "http": { + "method": "GET", + "path": "/" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["cookie"] == "existing=value; new=cookie1; another=cookie2" + } + + def "handles empty cookies array correctly"() { + given: + def eventJson = ''' + { + "headers": { + "content-type": "application/json" + }, + "cookies": [], + "requestContext": { + "http": { + "method": "GET", + "path": "/" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + !capturedHeaders.containsKey("cookie") // Empty array should not add cookie header + } + + // ============================================================================ + // processRequestEnd Tests + // ============================================================================ + + def "processRequestEnd does nothing when span is null"() { + when: + LambdaAppSecHandler.processRequestEnd(null) + + then: + noExceptionThrown() + } + + def "processRequestEnd does nothing when AppSec is disabled"() { + given: + ActiveSubsystems.APPSEC_ACTIVE = false + def span = Mock(AgentSpan) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + 0 * span._ + } + + def "processRequestEnd does nothing when span has no RequestContext"() { + given: + def span = Mock(AgentSpan) { + getRequestContext() >> null + } + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + noExceptionThrown() + } + + def "processRequestEnd invokes requestEnded callback with RequestContext"() { + given: + def mockAppSecContext = new Object() + def mockRequestContext = Mock(RequestContext) { + getData(RequestContextSlot.APPSEC) >> mockAppSecContext + } + def span = Mock(AgentSpan) { + getRequestContext() >> mockRequestContext + } + + def callbackInvoked = false + def capturedContext = null + def capturedSpan = null + + def mockRequestEndedCallback = Mock(BiFunction) { + apply(_ as RequestContext, _ as AgentSpan) >> { RequestContext ctx, AgentSpan s -> + callbackInvoked = true + capturedContext = ctx + capturedSpan = s + return new Flow.ResultFlow<>(null) + } + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestEnded()) >> mockRequestEndedCallback + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + callbackInvoked + capturedContext == mockRequestContext + capturedSpan == span + } + + def "processRequestEnd handles null requestEnded callback gracefully"() { + given: + def mockRequestContext = Mock(RequestContext) + def span = Mock(AgentSpan) { + getRequestContext() >> mockRequestContext + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestEnded()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + noExceptionThrown() // Should log warning but not throw + } + + // ============================================================================ + // mergeContexts Tests + // ============================================================================ + + def "mergeContexts returns null when both contexts are null"() { + when: + def result = LambdaAppSecHandler.mergeContexts(null, null) + + then: + result == null + } + + def "mergeContexts returns extensionContext when appSecContext is null"() { + given: + def extensionContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, null) + + then: + result == extensionContext + } + + def "mergeContexts returns appSecContext when extensionContext is null"() { + given: + def appSecContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(null, appSecContext) + + then: + result == appSecContext + } + + def "mergeContexts merges AppSec data into TagContext"() { + given: + def appSecData = new Object() + + // Create real TagContext instances since methods are final + def appSecContext = new TagContext() + appSecContext.withRequestContextDataAppSec(appSecData) + + def extensionContext = new TagContext() + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + result.getRequestContextDataAppSec() == appSecData + } + + def "mergeContexts returns extensionContext when appSecContext is not TagContext"() { + given: + def extensionContext = Mock(TagContext) + def appSecContext = Mock(AgentSpanContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + } + + def "mergeContexts returns extensionContext when it is not TagContext"() { + given: + def extensionContext = Mock(AgentSpanContext) + def appSecContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + } + + // ============================================================================ + // Error Handling and Null Callback Tests + // ============================================================================ + + def "processRequestStart handles null requestStarted callback gracefully"() { + given: + def eventJson = '{"requestContext": {"httpMethod": "GET"}}' + def event = createInputStream(eventJson) + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null // Should return null when requestStarted callback is missing + } + + def "processRequestStart handles null methodUri callback gracefully"() { + given: + def eventJson = ''' + { + "path": "/test", + "requestContext": { + "httpMethod": "GET" + } + } + ''' + def event = createInputStream(eventJson) + + def mockAppSecContext = new Object() + + def mockRequestStartedCallback = Mock(Supplier) { + get() >> new Flow.ResultFlow<>(mockAppSecContext) + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback + getCallback(EVENTS.requestMethodUriRaw()) >> null // Null callback + getCallback(EVENTS.requestHeader()) >> null + getCallback(EVENTS.requestClientSocketAddress()) >> null + getCallback(EVENTS.requestHeaderDone()) >> Mock(Function) { + apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) + } + getCallback(EVENTS.requestPathParams()) >> null + getCallback(EVENTS.requestBodyProcessed()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null // Should continue processing even if methodUri callback is null + result instanceof TagContext + } + + def "processRequestStart handles exception during JSON parsing"() { + given: + def invalidJson = '{this is not valid JSON at all' + def event = createInputStream(invalidJson) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null // Should return null on parse error + } + + def "processRequestStart handles exception during stream reading"() { + given: + def mockStream = Mock(ByteArrayInputStream) { + available() >> { throw new IOException("Stream error") } + } + + when: + def result = LambdaAppSecHandler.processRequestStart(mockStream) + + then: + result == null // Should return null on IO error + } + + // ============================================================================ + // Helper Methods + // ============================================================================ + + private ByteArrayInputStream createInputStream(String json) { + return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + } + + /** + * Set up mock callbacks to capture invocations and verify data extraction. + * This mocks the AgentTracer and callback provider to intercept gateway calls. + */ + private void setupMockCallbacks(Map callbacks) { + def mockAppSecContext = new Object() + + def mockRequestStartedCallback = Mock(Supplier) { + get() >> new Flow.ResultFlow<>(mockAppSecContext) + } + + def mockMethodUriCallback = callbacks.onMethodUri ? Mock(datadog.trace.api.function.TriFunction) { + apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { RequestContext ctx, String method, URIDataAdapter uri -> + callbacks.onMethodUri(method, uri) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { + accept(_ as RequestContext, _ as String, _ as String) >> { RequestContext ctx, String name, String value -> + callbacks.onHeader(name, value) + } + } : null + + def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { + apply(_ as RequestContext, _ as String, _ as Integer) >> { RequestContext ctx, String ip, Integer port -> + callbacks.onSocketAddress(ip, port) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockHeaderDoneCallback = Mock(Function) { + apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) + } + + def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + callbacks.onPathParams(params) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockQueryParamsCallback = callbacks.onQueryParams ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + callbacks.onQueryParams(params) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Object) >> { RequestContext ctx, Object body -> + callbacks.onBody(body) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback + getCallback(EVENTS.requestMethodUriRaw()) >> mockMethodUriCallback + getCallback(EVENTS.requestHeader()) >> mockHeaderCallback + getCallback(EVENTS.requestClientSocketAddress()) >> mockSocketAddressCallback + getCallback(EVENTS.requestHeaderDone()) >> mockHeaderDoneCallback + getCallback(EVENTS.requestPathParams()) >> mockPathParamsCallback + getCallback(EVENTS.requestBodyProcessed()) >> mockBodyCallback + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + // Install the mock tracer + AgentTracer.forceRegister(mockTracer) + } + + def cleanup() { + // Reset tracer after each test + AgentTracer.forceRegister(null) + } +} From e4da8b36e9658f6fec73ce9c5d8e17dc487603d3 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Wed, 11 Feb 2026 16:59:13 +0100 Subject: [PATCH 03/15] add better support for query parameters --- .../trace/lambda/LambdaAppSecHandler.java | 119 ++++++++++++++++-- .../lambda/LambdaAppSecHandlerTest.groovy | 1 + 2 files changed, 110 insertions(+), 10 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 7d05c43ad7b..c4151b8c228 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -158,7 +159,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa datadog.trace.api.function.TriFunction> methodUriCallback = tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); if (methodUriCallback != null) { - LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); + // Reconstruct full path with query string for AppSec analysis + String fullPath = buildFullPath(eventData.path, eventData.queryParameters); + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(fullPath); methodUriCallback.apply(requestContext, eventData.method, uriAdapter); } else { log.warn("requestMethodUriRaw callback is null"); @@ -236,7 +239,7 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", availableBytes, MAX_EVENT_SIZE); return new LambdaEventData(Collections.emptyMap(), null, null, null, null, - LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } StringBuilder jsonBuilder = new StringBuilder(availableBytes); @@ -260,7 +263,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { log.debug("Event JSON parsed successfully"); if (event == null) { - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } // Detect trigger type @@ -285,7 +288,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } catch (Exception e) { log.error("Failed to parse event data from JSON", e); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } } @@ -340,6 +343,7 @@ static LambdaTriggerType detectTriggerType(Map event) { private static LambdaEventData extractApiGatewayV1Data(Map event) { Map headers = extractHeaders(event.get("headers")); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -353,7 +357,7 @@ private static LambdaEventData extractApiGatewayV1Data(Map event sourceIp = (String) identity.get("sourceIp"); } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, queryParameters, body); } /** @@ -362,6 +366,7 @@ private static LambdaEventData extractApiGatewayV1Data(Map event private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -378,7 +383,7 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e sourcePort = ((Number) portObj).intValue(); } - return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, queryParameters, body); } /** @@ -387,6 +392,7 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -402,7 +408,7 @@ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event, LambdaT } Map pathParameters = extractPathParameters(event.get("pathParameters")); + + // ALB can have both queryStringParameters and multiValueQueryStringParameters + Map> queryParameters; + if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { + queryParameters = extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); + } else { + queryParameters = extractQueryParameters(event.get("queryStringParameters")); + } + Object body = extractBody(event); String method = (String) event.get("httpMethod"); String path = (String) event.get("path"); String sourceIp = headers.get("x-forwarded-for"); - return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); } /** @@ -453,6 +468,7 @@ private static LambdaEventData extractAlbData(Map event, LambdaT private static LambdaEventData extractGenericData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); String method = null; @@ -498,7 +514,7 @@ private static LambdaEventData extractGenericData(Map event) { } } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, queryParameters, body); } /** @@ -541,6 +557,87 @@ private static Map extractPathParameters(Object pathParamsObj) { return pathParams; } + /** + * Helper method to extract query parameters from event. + * Converts Map to Map> format expected by AppSec. + */ + private static Map> extractQueryParameters(Object queryParamsObj) { + Map> result = new java.util.HashMap<>(); + if (queryParamsObj instanceof Map) { + Map rawMap = (Map) queryParamsObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + String value = String.valueOf(entry.getValue()); + result.put(key, java.util.Collections.singletonList(value)); + } + } + } + log.debug("Extracted {} query parameters", result.size()); + return result; + } + + /** + * Helper method to extract multi-value query parameters (used by ALB). + * Handles Map> format directly. + */ + private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { + Map> result = new java.util.HashMap<>(); + if (queryParamsObj instanceof Map) { + Map rawMap = (Map) queryParamsObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + if (entry.getValue() instanceof java.util.List) { + java.util.List values = (java.util.List) entry.getValue(); + java.util.List stringValues = new java.util.ArrayList<>(); + for (Object value : values) { + if (value != null) { + stringValues.add(String.valueOf(value)); + } + } + result.put(key, stringValues); + } else { + result.put(key, java.util.Collections.singletonList(String.valueOf(entry.getValue()))); + } + } + } + } + log.debug("Extracted {} multi-value query parameters", result.size()); + return result; + } + + /** + * Helper method to build full path including query string. + * Lambda events provide path and query parameters separately, so we need to reconstruct + * the full URI for AppSec to parse. + */ + private static String buildFullPath(String path, Map> queryParameters) { + if (queryParameters == null || queryParameters.isEmpty()) { + return path; + } + + StringBuilder fullPath = new StringBuilder(path); + fullPath.append('?'); + + boolean first = true; + for (Map.Entry> entry : queryParameters.entrySet()) { + String key = entry.getKey(); + for (String value : entry.getValue()) { + if (!first) { + fullPath.append('&'); + } + first = false; + fullPath.append(key); + if (value != null) { + fullPath.append('=').append(value); + } + } + } + + return fullPath.toString(); + } + /** * Helper method to extract and merge headers with cookies array from event. * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. @@ -700,9 +797,10 @@ static class LambdaEventData { final Integer sourcePort; final LambdaTriggerType triggerType; final Map pathParameters; + final Map> queryParameters; final Object body; - LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { + LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Map> queryParameters, Object body) { this.headers = headers; this.method = method; this.path = path; @@ -710,6 +808,7 @@ static class LambdaEventData { this.sourcePort = sourcePort; this.triggerType = triggerType; this.pathParameters = pathParameters; + this.queryParameters = queryParameters; this.body = body; } } diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index eea29fe68c8..59f72c0a98e 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -2,6 +2,7 @@ package datadog.trace.lambda import datadog.trace.api.Config import datadog.trace.api.function.TriConsumer +import datadog.trace.api.function.TriFunction import datadog.trace.api.gateway.CallbackProvider import datadog.trace.api.gateway.Flow import datadog.trace.api.gateway.RequestContext From ec2be135b9de7d878eb5419716c58d1e33b4383b Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Wed, 11 Feb 2026 17:16:00 +0100 Subject: [PATCH 04/15] apply spotless --- .../trace/lambda/LambdaAppSecHandler.java | 265 +++++++++++------- .../lambda/LambdaAppSecHandlerTest.groovy | 221 ++++++++------- 2 files changed, 282 insertions(+), 204 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index c4151b8c228..10bb7c03f3a 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -36,8 +36,8 @@ import org.slf4j.LoggerFactory; /** - * Handles AppSec processing for AWS Lambda invocations. - * Extracts Lambda event data and invokes AppSec gateway callbacks. + * Handles AppSec processing for AWS Lambda invocations. Extracts Lambda event data and invokes + * AppSec gateway callbacks. */ public class LambdaAppSecHandler { @@ -50,11 +50,12 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); /** - * Process AppSec request data at the start of a Lambda invocation. - * Extract event data and invokes all relevant AppSec gateway callbacks. + * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes + * all relevant AppSec gateway callbacks. * * @param event the Lambda event object - * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing fails + * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing + * fails */ public static AgentSpanContext processRequestStart(Object event) { if (!ActiveSubsystems.APPSEC_ACTIVE) { @@ -63,7 +64,9 @@ public static AgentSpanContext processRequestStart(Object event) { } if (!(event instanceof ByteArrayInputStream)) { - log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); + log.debug( + "Event is not a ByteArrayInputStream, type: {}", + event != null ? event.getClass().getName() : "null"); return null; } @@ -90,8 +93,7 @@ public static void processRequestEnd(AgentSpan span) { if (requestContext != null) { AgentTracer.TracerAPI tracer = AgentTracer.get(); BiFunction> requestEndedCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestEnded()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestEnded()); if (requestEndedCallback != null) { requestEndedCallback.apply(requestContext, span); } else { @@ -156,8 +158,11 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestMethodUriRaw if (eventData.method != null && eventData.path != null) { - datadog.trace.api.function.TriFunction> methodUriCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); + datadog.trace.api.function.TriFunction> + methodUriCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestMethodUriRaw()); if (methodUriCallback != null) { // Reconstruct full path with query string for AppSec analysis String fullPath = buildFullPath(eventData.path, eventData.queryParameters); @@ -171,7 +176,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestHeader for each header if (eventData.headers != null && !eventData.headers.isEmpty()) { TriConsumer headerCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestHeader()); if (headerCallback != null) { for (Map.Entry header : eventData.headers.entrySet()) { headerCallback.accept(requestContext, header.getKey(), header.getValue()); @@ -183,8 +190,11 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestClientSocketAddress if (eventData.sourceIp != null) { - datadog.trace.api.function.TriFunction> socketAddrCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); + datadog.trace.api.function.TriFunction> + socketAddrCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestClientSocketAddress()); if (socketAddrCallback != null) { Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; socketAddrCallback.apply(requestContext, eventData.sourceIp, port); @@ -207,7 +217,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestPathParams if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { BiFunction, Flow> pathParamsCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestPathParams()); if (pathParamsCallback != null) { pathParamsCallback.apply(requestContext, eventData.pathParameters); } else { @@ -218,7 +230,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestBodyProcessed if (eventData.body != null) { BiFunction> bodyCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestBodyProcessed()); if (bodyCallback != null) { bodyCallback.apply(requestContext, eventData.body); } else { @@ -236,10 +250,20 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream int availableBytes = inputStream.available(); if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { - log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", - availableBytes, MAX_EVENT_SIZE); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, - LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + log.warn( + "Event size {} exceeds limit {} or is invalid, skipping AppSec processing", + availableBytes, + MAX_EVENT_SIZE); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } StringBuilder jsonBuilder = new StringBuilder(availableBytes); @@ -263,7 +287,16 @@ private static LambdaEventData extractEventDataFromJson(String json) { log.debug("Event JSON parsed successfully"); if (event == null) { - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } // Detect trigger type @@ -288,7 +321,16 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } catch (Exception e) { log.error("Failed to parse event data from JSON", e); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } } @@ -308,8 +350,8 @@ static LambdaTriggerType detectTriggerType(Map event) { } // Check for WebSocket - if (requestContext.containsKey("connectionId") && - (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { + if (requestContext.containsKey("connectionId") + && (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; } @@ -337,13 +379,12 @@ static LambdaTriggerType detectTriggerType(Map event) { return LambdaTriggerType.UNKNOWN; } - /** - * Extracts data from API Gateway v1 (REST API) event - */ + /** Extracts data from API Gateway v1 (REST API) event */ private static LambdaEventData extractApiGatewayV1Data(Map event) { Map headers = extractHeaders(event.get("headers")); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -357,16 +398,25 @@ private static LambdaEventData extractApiGatewayV1Data(Map event sourceIp = (String) identity.get("sourceIp"); } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + null, + LambdaTriggerType.API_GATEWAY_V1_REST, + pathParameters, + queryParameters, + body); } - /** - * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event - */ - private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { + /** Extracts data from API Gateway v2 (HTTP API) or Lambda URL event */ + private static LambdaEventData extractApiGatewayV2HttpData( + Map event, LambdaTriggerType triggerType) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -383,16 +433,24 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e sourcePort = ((Number) portObj).intValue(); } - return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + sourcePort, + triggerType, + pathParameters, + queryParameters, + body); } - /** - * Extracts data from API Gateway v2 WebSocket event - */ + /** Extracts data from API Gateway v2 WebSocket event */ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -408,13 +466,21 @@ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event, LambdaTriggerType triggerType) { + /** Extracts data from ALB event (with or without multi-value headers) */ + private static LambdaEventData extractAlbData( + Map event, LambdaTriggerType triggerType) { Map headers; if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { @@ -429,9 +495,10 @@ private static LambdaEventData extractAlbData(Map event, LambdaT if (entry.getValue() instanceof java.util.List) { java.util.List values = (java.util.List) entry.getValue(); // Join multiple values with comma - String joinedValue = values.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining(", ")); + String joinedValue = + values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); headers.put(key, joinedValue); } else { headers.put(key, String.valueOf(entry.getValue())); @@ -448,7 +515,8 @@ private static LambdaEventData extractAlbData(Map event, LambdaT // ALB can have both queryStringParameters and multiValueQueryStringParameters Map> queryParameters; if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { - queryParameters = extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); + queryParameters = + extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); } else { queryParameters = extractQueryParameters(event.get("queryStringParameters")); } @@ -459,16 +527,16 @@ private static LambdaEventData extractAlbData(Map event, LambdaT String path = (String) event.get("path"); String sourceIp = headers.get("x-forwarded-for"); - return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); } - /** - * Generic data extraction for unknown trigger types (fallback) - */ + /** Generic data extraction for unknown trigger types (fallback) */ private static LambdaEventData extractGenericData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); String method = null; @@ -514,12 +582,21 @@ private static LambdaEventData extractGenericData(Map event) { } } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + null, + LambdaTriggerType.UNKNOWN, + pathParameters, + queryParameters, + body); } /** - * Generic helper method to extract string key-value pairs from an object. - * Converts all keys and values to strings, filtering out null entries. + * Generic helper method to extract string key-value pairs from an object. Converts all keys and + * values to strings, filtering out null entries. */ private static Map extractStringMap(Object mapObj) { Map result = new java.util.HashMap<>(); @@ -536,9 +613,7 @@ private static Map extractStringMap(Object mapObj) { return result; } - /** - * Helper method to extract headers from event - */ + /** Helper method to extract headers from event */ private static Map extractHeaders(Object headersObj) { Map headers = extractStringMap(headersObj); log.debug("Extracted {} headers", headers.size()); @@ -548,9 +623,7 @@ private static Map extractHeaders(Object headersObj) { return headers; } - /** - * Helper method to extract path parameters from event - */ + /** Helper method to extract path parameters from event */ private static Map extractPathParameters(Object pathParamsObj) { Map pathParams = extractStringMap(pathParamsObj); log.debug("Extracted {} path parameters", pathParams.size()); @@ -558,8 +631,8 @@ private static Map extractPathParameters(Object pathParamsObj) { } /** - * Helper method to extract query parameters from event. - * Converts Map to Map> format expected by AppSec. + * Helper method to extract query parameters from event. Converts Map to + * Map> format expected by AppSec. */ private static Map> extractQueryParameters(Object queryParamsObj) { Map> result = new java.util.HashMap<>(); @@ -578,8 +651,8 @@ private static Map> extractQueryParameters(Object queryPara } /** - * Helper method to extract multi-value query parameters (used by ALB). - * Handles Map> format directly. + * Helper method to extract multi-value query parameters (used by ALB). Handles Map> format directly. */ private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { Map> result = new java.util.HashMap<>(); @@ -608,9 +681,8 @@ private static Map> extractMultiValueQueryParameters(Object } /** - * Helper method to build full path including query string. - * Lambda events provide path and query parameters separately, so we need to reconstruct - * the full URI for AppSec to parse. + * Helper method to build full path including query string. Lambda events provide path and query + * parameters separately, so we need to reconstruct the full URI for AppSec to parse. */ private static String buildFullPath(String path, Map> queryParameters) { if (queryParameters == null || queryParameters.isEmpty()) { @@ -639,8 +711,8 @@ private static String buildFullPath(String path, Map> query } /** - * Helper method to extract and merge headers with cookies array from event. - * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. + * Helper method to extract and merge headers with cookies array from event. API Gateway v2 + * provides a separate 'cookies' array that should be merged with headers. */ private static Map extractHeadersWithCookies(Map event) { Map headers = extractHeaders(event.get("headers")); @@ -651,9 +723,10 @@ private static Map extractHeadersWithCookies(Map java.util.List cookiesList = (java.util.List) cookiesObj; if (!cookiesList.isEmpty()) { // Join cookies with "; " separator per RFC 6265 - String cookieValue = cookiesList.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining("; ")); + String cookieValue = + cookiesList.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining("; ")); // Merge with existing cookie header if present String existingCookie = headers.get("cookie"); @@ -668,9 +741,7 @@ private static Map extractHeadersWithCookies(Map return headers; } - /** - * Helper method to extract and parse body from event - */ + /** Helper method to extract and parse body from event */ private static Object extractBody(Map event) { Object bodyObj = event.get("body"); if (bodyObj == null) { @@ -702,9 +773,7 @@ private static Object extractBody(Map event) { return bodyString; } - /** - * Helper method to parse body as JSON - */ + /** Helper method to parse body as JSON */ private static Object parseBodyAsJson(String body) { if (body == null || body.isEmpty() || "null".equals(body)) { return null; @@ -719,8 +788,7 @@ private static Object parseBodyAsJson(String body) { } /** - * Temporary RequestContext implementation to hold AppSecRequestContext - * before a span is created. + * Temporary RequestContext implementation to hold AppSecRequestContext before a span is created. */ private static class TemporaryRequestContext implements RequestContext { private final Object appSecRequestContext; @@ -773,22 +841,18 @@ public void close() { } } - /** - * Enum representing different AWS Lambda trigger types - */ + /** Enum representing different AWS Lambda trigger types */ enum LambdaTriggerType { - API_GATEWAY_V1_REST, // API Gateway REST API (v1) - API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) + API_GATEWAY_V1_REST, // API Gateway REST API (v1) + API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket - ALB, // Application Load Balancer - ALB_MULTI_VALUE, // ALB with multi-value headers - LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger + ALB, // Application Load Balancer + ALB_MULTI_VALUE, // ALB with multi-value headers + LAMBDA_URL, // Lambda Function URL + UNKNOWN // Unknown or unsupported trigger } - /** - * Object for Lambda event data needed for AppSec processing - */ + /** Object for Lambda event data needed for AppSec processing */ static class LambdaEventData { final Map headers; final String method; @@ -800,7 +864,16 @@ static class LambdaEventData { final Map> queryParameters; final Object body; - LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Map> queryParameters, Object body) { + LambdaEventData( + Map headers, + String method, + String path, + String sourceIp, + Integer sourcePort, + LambdaTriggerType triggerType, + Map pathParameters, + Map> queryParameters, + Object body) { this.headers = headers; this.method = method; this.path = path; @@ -813,9 +886,7 @@ static class LambdaEventData { } } - /** - * URIDataAdapter implementation for Lambda events. - */ + /** URIDataAdapter implementation for Lambda events. */ private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; private final String query; diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index 59f72c0a98e..00f58c7ef36 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -332,23 +332,23 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - }, - onBody: { body -> - capturedBody = body - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + }, + onBody: { body -> + capturedBody = body + } ) when: @@ -404,20 +404,20 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedPathParams = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + } ) when: @@ -459,10 +459,10 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedPath = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + } ) when: @@ -498,13 +498,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -539,9 +539,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -575,9 +575,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -609,13 +609,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -646,9 +646,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -666,9 +666,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -686,9 +686,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -715,10 +715,10 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedQuery = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedPath = uri.path() - capturedQuery = uri.query() - } + onMethodUri: { method, uri -> + capturedPath = uri.path() + capturedQuery = uri.query() + } ) when: @@ -746,9 +746,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -776,9 +776,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -804,9 +804,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -847,16 +847,16 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -890,13 +890,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -930,9 +930,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -964,9 +964,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -1029,7 +1029,8 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSpan = null def mockRequestEndedCallback = Mock(BiFunction) { - apply(_ as RequestContext, _ as AgentSpan) >> { RequestContext ctx, AgentSpan s -> + apply(_ as RequestContext, _ as AgentSpan) >> { + RequestContext ctx, AgentSpan s -> callbackInvoked = true capturedContext = ctx capturedSpan = s @@ -1271,20 +1272,23 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { } def mockMethodUriCallback = callbacks.onMethodUri ? Mock(datadog.trace.api.function.TriFunction) { - apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { RequestContext ctx, String method, URIDataAdapter uri -> + apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { + RequestContext ctx, String method, URIDataAdapter uri -> callbacks.onMethodUri(method, uri) return new Flow.ResultFlow<>(null) } } : null def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { - accept(_ as RequestContext, _ as String, _ as String) >> { RequestContext ctx, String name, String value -> + accept(_ as RequestContext, _ as String, _ as String) >> { + RequestContext ctx, String name, String value -> callbacks.onHeader(name, value) } } : null def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as Integer) >> { RequestContext ctx, String ip, Integer port -> + apply(_ as RequestContext, _ as String, _ as Integer) >> { + RequestContext ctx, String ip, Integer port -> callbacks.onSocketAddress(ip, port) return new Flow.ResultFlow<>(null) } @@ -1295,21 +1299,24 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { } def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + apply(_ as RequestContext, _ as Map) >> { + RequestContext ctx, Map params -> callbacks.onPathParams(params) return new Flow.ResultFlow<>(null) } } : null def mockQueryParamsCallback = callbacks.onQueryParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + apply(_ as RequestContext, _ as Map) >> { + RequestContext ctx, Map params -> callbacks.onQueryParams(params) return new Flow.ResultFlow<>(null) } } : null def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Object) >> { RequestContext ctx, Object body -> + apply(_ as RequestContext, _ as Object) >> { + RequestContext ctx, Object body -> callbacks.onBody(body) return new Flow.ResultFlow<>(null) } From 012235de45de6f6df0ffcdf573ae032f930740c6 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 27 Mar 2026 16:48:20 +0100 Subject: [PATCH 05/15] fix test crash --- .../groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index 00f58c7ef36..f508d1a0dd3 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -111,7 +111,7 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { when: LambdaAppSecHandler.processRequestStart(event) event.reset() - def content = new String(event.readAllBytes(), StandardCharsets.UTF_8) + def content = new String(event.bytes, StandardCharsets.UTF_8) then: content == jsonData From 9e5ffac9754dff6c0e14b625ed322215af981569 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 27 Mar 2026 17:19:55 +0100 Subject: [PATCH 06/15] remove unused var --- .../datadog/trace/lambda/LambdaAppSecHandlerTest.groovy | 8 -------- 1 file changed, 8 deletions(-) diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index f508d1a0dd3..822659b95f4 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -1306,14 +1306,6 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { } } : null - def mockQueryParamsCallback = callbacks.onQueryParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { - RequestContext ctx, Map params -> - callbacks.onQueryParams(params) - return new Flow.ResultFlow<>(null) - } - } : null - def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { apply(_ as RequestContext, _ as Object) >> { RequestContext ctx, Object body -> From 7c0bf91c185bdab4cebe458fd2ee4b658d87cdfa Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 10 Apr 2026 14:53:34 +0200 Subject: [PATCH 07/15] forwarded headers parsing + downgrade log level --- .../trace/lambda/LambdaAppSecHandler.java | 94 ++++++++------- .../lambda/LambdaAppSecHandlerTest.groovy | 107 +++++++++++++++++- 2 files changed, 151 insertions(+), 50 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 10bb7c03f3a..4d8b04268fd 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -72,9 +72,12 @@ public static AgentSpanContext processRequestStart(Object event) { try { LambdaEventData eventData = extractEventData((ByteArrayInputStream) event); + if (eventData == LambdaEventData.EMPTY) { + return null; + } return processAppSecRequestData(eventData); } catch (Exception e) { - log.error("Failed to process AppSec request data", e); + log.debug("Failed to process AppSec request data", e); return null; } } @@ -97,7 +100,7 @@ public static void processRequestEnd(AgentSpan span) { if (requestEndedCallback != null) { requestEndedCallback.apply(requestContext, span); } else { - log.warn("requestEnded callback is null"); + log.debug("requestEnded callback is null"); } } } @@ -130,7 +133,7 @@ public static AgentSpanContext mergeContexts( return merged; } - log.warn( + log.debug( "Cannot merge AppSec data: extension context is not a TagContext: {}", extensionContext.getClass()); } @@ -142,7 +145,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa Supplier> requestStartedCallback = tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestStarted()); if (requestStartedCallback == null) { - log.warn("requestStarted callback is null"); + log.debug("requestStarted callback is null"); return null; } @@ -166,10 +169,10 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa if (methodUriCallback != null) { // Reconstruct full path with query string for AppSec analysis String fullPath = buildFullPath(eventData.path, eventData.queryParameters); - LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(fullPath); + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(fullPath, eventData.headers); methodUriCallback.apply(requestContext, eventData.method, uriAdapter); } else { - log.warn("requestMethodUriRaw callback is null"); + log.debug("requestMethodUriRaw callback is null"); } } @@ -184,7 +187,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa headerCallback.accept(requestContext, header.getKey(), header.getValue()); } } else { - log.warn("requestHeader callback is null"); + log.debug("requestHeader callback is null"); } } @@ -199,7 +202,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; socketAddrCallback.apply(requestContext, eventData.sourceIp, port); } else { - log.warn("requestClientSocketAddress callback is null"); + log.debug("requestClientSocketAddress callback is null"); } } @@ -211,7 +214,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa if (headerDoneCallback != null) { headerDoneCallback.apply(requestContext); } else { - log.warn("requestHeaderDone callback is null"); + log.debug("requestHeaderDone callback is null"); } // Call requestPathParams @@ -223,7 +226,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa if (pathParamsCallback != null) { pathParamsCallback.apply(requestContext, eventData.pathParameters); } else { - log.warn("requestPathParams callback is null"); + log.debug("requestPathParams callback is null"); } } @@ -236,7 +239,7 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa if (bodyCallback != null) { bodyCallback.apply(requestContext, eventData.body); } else { - log.warn("requestBodyProcessed callback is null"); + log.debug("requestBodyProcessed callback is null"); } } } @@ -250,20 +253,11 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream int availableBytes = inputStream.available(); if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { - log.warn( + log.debug( "Event size {} exceeds limit {} or is invalid, skipping AppSec processing", availableBytes, MAX_EVENT_SIZE); - return new LambdaEventData( - Collections.emptyMap(), - null, - null, - null, - null, - LambdaTriggerType.UNKNOWN, - Collections.emptyMap(), - Collections.emptyMap(), - null); + return LambdaEventData.EMPTY; } StringBuilder jsonBuilder = new StringBuilder(availableBytes); @@ -287,16 +281,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { log.debug("Event JSON parsed successfully"); if (event == null) { - return new LambdaEventData( - Collections.emptyMap(), - null, - null, - null, - null, - LambdaTriggerType.UNKNOWN, - Collections.emptyMap(), - Collections.emptyMap(), - null); + return LambdaEventData.EMPTY; } // Detect trigger type @@ -320,17 +305,8 @@ private static LambdaEventData extractEventDataFromJson(String json) { return extractGenericData(event); } } catch (Exception e) { - log.error("Failed to parse event data from JSON", e); - return new LambdaEventData( - Collections.emptyMap(), - null, - null, - null, - null, - LambdaTriggerType.UNKNOWN, - Collections.emptyMap(), - Collections.emptyMap(), - null); + log.debug("Failed to parse event data from JSON", e); + return LambdaEventData.EMPTY; } } @@ -525,7 +501,8 @@ private static LambdaEventData extractAlbData( String method = (String) event.get("httpMethod"); String path = (String) event.get("path"); - String sourceIp = headers.get("x-forwarded-for"); + String xff = headers.get("x-forwarded-for"); + String sourceIp = xff != null ? xff.split(",")[0].trim() : null; return new LambdaEventData( headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); @@ -780,8 +757,7 @@ private static Object parseBodyAsJson(String body) { } try { - Object parsed = OBJECT_ADAPTER.fromJson(body); - return parsed; + return OBJECT_ADAPTER.fromJson(body); } catch (Exception e) { return null; } @@ -864,6 +840,11 @@ static class LambdaEventData { final Map> queryParameters; final Object body; + static final LambdaEventData EMPTY = new LambdaEventData( + Collections.emptyMap(), null, null, null, null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), Collections.emptyMap(), null); + LambdaEventData( Map headers, String method, @@ -890,8 +871,10 @@ static class LambdaEventData { private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; private final String query; + private final String scheme; + private final int port; - LambdaURIDataAdapter(String pathWithQuery) { + LambdaURIDataAdapter(String pathWithQuery, Map headers) { if (pathWithQuery != null) { int queryIndex = pathWithQuery.indexOf('?'); if (queryIndex != -1) { @@ -905,11 +888,24 @@ private static class LambdaURIDataAdapter extends URIDataAdapterBase { this.path = "/"; this.query = null; } + + String forwardedProto = headers != null ? headers.get("x-forwarded-proto") : null; + this.scheme = (forwardedProto != null && !forwardedProto.isEmpty()) ? forwardedProto : "https"; + + String forwardedPort = headers != null ? headers.get("x-forwarded-port") : null; + int parsedPort = -1; + if (forwardedPort != null && !forwardedPort.isEmpty()) { + try { + parsedPort = Integer.parseInt(forwardedPort.trim()); + } catch (NumberFormatException ignored) { + } + } + this.port = parsedPort > 0 ? parsedPort : 443; } @Override public String scheme() { - return "https"; + return scheme; } @Override @@ -919,7 +915,7 @@ public String host() { @Override public int port() { - return 443; + return port; } @Override diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index 822659b95f4..750cfe4ab78 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -730,6 +730,111 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { capturedQuery == "id=123&filter=active" } + def "extracts scheme and port from X-Forwarded headers"() { + given: + def eventJson = ''' + { + "path": "/api/test", + "headers": { + "x-forwarded-proto": "http", + "x-forwarded-port": "8080" + }, + "requestContext": { + "httpMethod": "GET", + "requestId": "req-123" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedScheme = null + def capturedPort = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedScheme = uri.scheme() + capturedPort = uri.port() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedScheme == "http" + capturedPort == 8080 + } + + def "falls back to https/443 when X-Forwarded headers are absent"() { + given: + def eventJson = ''' + { + "path": "/api/test", + "headers": {}, + "requestContext": { + "httpMethod": "GET", + "requestId": "req-123" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedScheme = null + def capturedPort = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedScheme = uri.scheme() + capturedPort = uri.port() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedScheme == "https" + capturedPort == 443 + } + + def "handles invalid X-Forwarded-Port gracefully"() { + given: + def eventJson = ''' + { + "path": "/api/test", + "headers": { + "x-forwarded-proto": "https", + "x-forwarded-port": "not-a-number" + }, + "requestContext": { + "httpMethod": "GET", + "requestId": "req-123" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedScheme = null + def capturedPort = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedScheme = uri.scheme() + capturedPort = uri.port() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedScheme == "https" + capturedPort == 443 + } + def "handles invalid base64 body gracefully"() { given: def eventJson = ''' @@ -1271,7 +1376,7 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { get() >> new Flow.ResultFlow<>(mockAppSecContext) } - def mockMethodUriCallback = callbacks.onMethodUri ? Mock(datadog.trace.api.function.TriFunction) { + def mockMethodUriCallback = callbacks.onMethodUri ? Mock(TriFunction) { apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { RequestContext ctx, String method, URIDataAdapter uri -> callbacks.onMethodUri(method, uri) From 9e64f6bf76d7ccfc7cc5872686bb3e3cc03cb50b Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 10 Apr 2026 15:02:01 +0200 Subject: [PATCH 08/15] formatting --- .../trace/lambda/LambdaAppSecHandler.java | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 4d8b04268fd..099527c5382 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -840,10 +840,17 @@ static class LambdaEventData { final Map> queryParameters; final Object body; - static final LambdaEventData EMPTY = new LambdaEventData( - Collections.emptyMap(), null, null, null, null, - LambdaTriggerType.UNKNOWN, - Collections.emptyMap(), Collections.emptyMap(), null); + static final LambdaEventData EMPTY = + new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); LambdaEventData( Map headers, @@ -890,7 +897,8 @@ private static class LambdaURIDataAdapter extends URIDataAdapterBase { } String forwardedProto = headers != null ? headers.get("x-forwarded-proto") : null; - this.scheme = (forwardedProto != null && !forwardedProto.isEmpty()) ? forwardedProto : "https"; + this.scheme = + (forwardedProto != null && !forwardedProto.isEmpty()) ? forwardedProto : "https"; String forwardedPort = headers != null ? headers.get("x-forwarded-port") : null; int parsedPort = -1; From 11086791d4cb695094f6d9bdc1c9187f0f211f0a Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 10 Apr 2026 15:40:29 +0200 Subject: [PATCH 09/15] WIP --- .../lambda/LambdaHandlerInstrumentation.java | 22 +- .../java/datadog/trace/core/CoreTracer.java | 3 +- .../trace/lambda/LambdaAppSecHandler.java | 189 ++++++++++++++++++ .../instrumentation/api/AgentTracer.java | 4 +- 4 files changed, 196 insertions(+), 22 deletions(-) diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index 610b24b2f31..f2f1e31d06c 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -24,7 +24,6 @@ import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; import datadog.trace.bootstrap.instrumentation.api.AgentTracer; import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; -import datadog.trace.bootstrap.instrumentation.api.ResourceNamePriorities; import datadog.trace.config.inversion.ConfigHelper; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.type.TypeDescription; @@ -94,9 +93,9 @@ static AgentScope enter( AgentSpanContext lambdaContext = AgentTracer.get().notifyLambdaStart(in, lambdaRequestId); final AgentSpan span; if (null == lambdaContext) { - span = startSpan("java-aws-sdk", INVOCATION_SPAN_NAME); + span = startSpan(INVOCATION_SPAN_NAME); } else { - span = startSpan("java-aws-sdk", INVOCATION_SPAN_NAME, lambdaContext); + span = startSpan(INVOCATION_SPAN_NAME, lambdaContext); } span.setSpanType(InternalSpanTypes.SERVERLESS); span.setTag("request_id", lambdaRequestId); @@ -126,22 +125,7 @@ static void exit( } String lambdaRequestId = awsContext.getAwsRequestId(); - AgentTracer.get().notifyAppSecEnd(span); - // Force the resource name back to the literal placeholder marker right - // before finish so that the Datadog Lambda Extension's filter - // (filter_span_from_lambda_library_or_runtime in - // bottlecap/src/traces/trace_processor.rs, which compares - // span.resource == "dd-tracer-serverless-span") drops the placeholder. - // Other instrumentation (HTTP/JAX-RS) may have overwritten it with the - // route ("POST /") during the invocation, in which case the extension - // would fail to dedup, leading to the placeholder leaking to the backend - // with parent_id=0 and detaching the inferred apigateway root from the - // rest of the trace. - // Use TAG_INTERCEPTOR priority because DDSpanContext.setResourceName - // ignores writes whose priority is below the current resource priority, - // and the HTTP/JAX-RS instrumentation will already have written - // HTTP_FRAMEWORK_ROUTE (3) by this point. - span.setResourceName(INVOCATION_SPAN_NAME, ResourceNamePriorities.TAG_INTERCEPTOR); + AgentTracer.get().notifyAppSecEnd(span, result); span.finish(); AgentTracer.get().notifyExtensionEnd(span, result, null != throwable, lambdaRequestId); } finally { diff --git a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java index 28f9e39c710..247f967cff9 100644 --- a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java +++ b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java @@ -1255,7 +1255,8 @@ public void notifyExtensionEnd( } @Override - public void notifyAppSecEnd(AgentSpan span) { + public void notifyAppSecEnd(AgentSpan span, Object result) { + LambdaAppSecHandler.processResponseData(span, result); LambdaAppSecHandler.processRequestEnd(span); } diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 099527c5382..853be632b4e 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -7,6 +7,7 @@ import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.CallbackProvider; import datadog.trace.api.gateway.Flow; import datadog.trace.api.gateway.IGSpanInfo; import datadog.trace.api.gateway.RequestContext; @@ -21,14 +22,18 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.Reader; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -49,6 +54,10 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); + private static final Set RESPONSE_HEADER_ALLOW_LIST = + new HashSet<>( + Arrays.asList("content-length", "content-type", "content-encoding", "content-language")); + /** * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes * all relevant AppSec gateway callbacks. @@ -105,6 +114,173 @@ public static void processRequestEnd(AgentSpan span) { } } + /** + * Process response data through WAF before the request context is closed. Extracts status code, + * headers, and body from the Lambda response and fires the corresponding gateway events. + * + * @param span the current span + * @param result the Lambda handler result (expected to be a ByteArrayOutputStream) + */ + public static void processResponseData(AgentSpan span, Object result) { + if (!ActiveSubsystems.APPSEC_ACTIVE + || span == null + || !(result instanceof ByteArrayOutputStream)) { + return; + } + + try { + byte[] bytes = ((ByteArrayOutputStream) result).toByteArray(); + if (bytes.length == 0 || bytes.length > MAX_EVENT_SIZE) { + log.debug( + "Response size {} exceeds limit {} or is empty, skipping response processing", + bytes.length, + MAX_EVENT_SIZE); + return; + } + + String json = new String(bytes, StandardCharsets.UTF_8); + LambdaResponseData responseData = extractResponseData(json); + if (responseData == null) { + return; + } + + RequestContext requestContext = span.getRequestContext(); + if (requestContext == null) { + log.debug("Span has no RequestContext, skipping response processing"); + return; + } + + AgentTracer.TracerAPI tracer = AgentTracer.get(); + CallbackProvider cbp = tracer.getCallbackProvider(RequestContextSlot.APPSEC); + + // Fire responseStarted + if (responseData.statusCode > 0) { + BiFunction> responseStartedCb = + cbp.getCallback(EVENTS.responseStarted()); + if (responseStartedCb != null) { + responseStartedCb.apply(requestContext, responseData.statusCode); + } + } + + // Fire responseHeader for each allowed header + if (responseData.headers != null && !responseData.headers.isEmpty()) { + TriConsumer responseHeaderCb = + cbp.getCallback(EVENTS.responseHeader()); + if (responseHeaderCb != null) { + for (Map.Entry header : responseData.headers.entrySet()) { + responseHeaderCb.accept(requestContext, header.getKey(), header.getValue()); + } + } + } + + // Fire responseHeaderDone + Function> responseHeaderDoneCb = + cbp.getCallback(EVENTS.responseHeaderDone()); + if (responseHeaderDoneCb != null) { + responseHeaderDoneCb.apply(requestContext); + } + + // Fire responseBody + if (responseData.body != null) { + BiFunction> responseBodyCb = + cbp.getCallback(EVENTS.responseBody()); + if (responseBodyCb != null) { + responseBodyCb.apply(requestContext, responseData.body); + } + } + } catch (Exception e) { + log.error("Failed to process AppSec response data", e); + } + } + + static LambdaResponseData extractResponseData(String json) { + try { + Map response = MAP_ADAPTER.fromJson(json); + if (response == null) { + return null; + } + + // Extract status code + int statusCode = 0; + Object statusCodeObj = response.get("statusCode"); + if (statusCodeObj instanceof Number) { + statusCode = ((Number) statusCodeObj).intValue(); + } + + // Extract and filter headers + Map headers = new java.util.HashMap<>(); + Map rawHeaders = extractStringMap(response.get("headers")); + + // Merge multiValueHeaders if present (API GW v1 / ALB) + Object multiValueHeadersObj = response.get("multiValueHeaders"); + if (multiValueHeadersObj instanceof Map) { + Map multiValueHeaders = (Map) multiValueHeadersObj; + for (Map.Entry entry : multiValueHeaders.entrySet()) { + if (entry.getKey() != null && entry.getValue() instanceof java.util.List) { + String key = String.valueOf(entry.getKey()); + java.util.List values = (java.util.List) entry.getValue(); + String joinedValue = + values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); + rawHeaders.put(key, joinedValue); + } + } + } + + // Filter to allow-list (case-insensitive) + for (Map.Entry entry : rawHeaders.entrySet()) { + if (RESPONSE_HEADER_ALLOW_LIST.contains(entry.getKey().toLowerCase())) { + headers.put(entry.getKey(), entry.getValue()); + } + } + + // Extract body + Object body = null; + Object bodyObj = response.get("body"); + if (bodyObj != null) { + String bodyString = String.valueOf(bodyObj); + + // Handle base64 encoding + Boolean isBase64Encoded = (Boolean) response.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64Encoded)) { + try { + bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); + } catch (Exception e) { + log.debug("Failed to decode base64 response body", e); + bodyString = null; + } + } + + if (bodyString != null) { + // Determine content-type from response headers + String contentType = null; + for (Map.Entry entry : headers.entrySet()) { + if ("content-type".equalsIgnoreCase(entry.getKey())) { + contentType = entry.getValue(); + break; + } + } + + // If JSON content-type or unknown, attempt JSON parsing + if (contentType == null + || contentType.contains("json") + || contentType.contains("javascript")) { + Object parsed = parseBodyAsJson(bodyString); + body = parsed != null ? parsed : bodyString; + } else { + body = bodyString; + } + } + } + + return new LambdaResponseData(statusCode, headers, body); + } catch (Exception e) { + log.debug("Failed to parse response data from JSON", e); + return null; + } + } + /** * Merge AppSec context data into extension context. * @@ -874,6 +1050,19 @@ static class LambdaEventData { } } + /** Data extracted from a Lambda response for WAF analysis */ + static class LambdaResponseData { + final int statusCode; + final Map headers; + final Object body; + + LambdaResponseData(int statusCode, Map headers, Object body) { + this.statusCode = statusCode; + this.headers = headers; + this.body = body; + } + } + /** URIDataAdapter implementation for Lambda events. */ private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; diff --git a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java index 14a5b5d6d21..8ef7b90a838 100644 --- a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java +++ b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java @@ -378,7 +378,7 @@ default AgentSpan blackholeSpan() { void notifyExtensionEnd(AgentSpan span, Object result, boolean isError, String lambdaRequestId); - void notifyAppSecEnd(AgentSpan span); + void notifyAppSecEnd(AgentSpan span, Object result); AgentDataStreamsMonitoring getDataStreamsMonitoring(); @@ -634,7 +634,7 @@ public void notifyExtensionEnd( AgentSpan span, Object result, boolean isError, String lambdaRequestId) {} @Override - public void notifyAppSecEnd(AgentSpan span) {} + public void notifyAppSecEnd(AgentSpan span, Object result) {} @Override public AgentDataStreamsMonitoring getDataStreamsMonitoring() { From 0967d32e3fb4b59b24435c60af27cc9ff9825c72 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 22 May 2026 14:17:11 +0200 Subject: [PATCH 10/15] wip --- .../HandlerStreamingWith404Response.java | 24 + .../HandlerStreamingWithApiGwResponse.java | 24 + .../LambdaHandlerInstrumentationTest.groovy | 169 +++++ .../trace/lambda/LambdaAppSecHandler.java | 49 +- .../lambda/LambdaAppSecHandlerTest.groovy | 614 ++++++++++++++++++ 5 files changed, 849 insertions(+), 31 deletions(-) create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java new file mode 100644 index 00000000000..540b245e802 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java @@ -0,0 +1,24 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; + +public class HandlerStreamingWith404Response implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + PrintWriter writer = + new PrintWriter( + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))); + writer.write( + "{\"statusCode\": 404, " + + "\"headers\": {\"content-type\": \"text/html\"}, " + + "\"body\": \"Not Found\"}"); + writer.close(); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java new file mode 100644 index 00000000000..7b91f1f3510 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java @@ -0,0 +1,24 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; + +public class HandlerStreamingWithApiGwResponse implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + PrintWriter writer = + new PrintWriter( + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))); + writer.write( + "{\"statusCode\": 200, " + + "\"headers\": {\"content-type\": \"application/json\", \"x-custom\": \"custom-val\"}, " + + "\"body\": \"{\\\"result\\\": \\\"ok\\\"}\"}"); + writer.close(); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy index ed1152ea1aa..51fbf25c62c 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy @@ -38,6 +38,12 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase def capturedBody = null def appSecEnded = false + // Response callback capture fields + def capturedResponseStatus = null + def capturedResponseHeaders = [:] + def capturedResponseBody = null + def responseHeaderDoneCalled = false + def setup() { ig = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC) ActiveSubsystems.APPSEC_ACTIVE = true @@ -47,6 +53,12 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase capturedHeaders = [:] capturedBody = null appSecEnded = false + capturedResponseStatus = null + capturedResponseHeaders = [:] + capturedResponseBody = null + responseHeaderDoneCalled = false + + // Request callbacks ig.registerCallback(EVENTS.requestStarted(), { appSecStarted = true new Flow.ResultFlow(new Object()) @@ -70,6 +82,23 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase appSecEnded = true Flow.ResultFlow.empty() } as BiFunction) + + // Response callbacks + ig.registerCallback(EVENTS.responseStarted(), { RequestContext ctx, Integer status -> + capturedResponseStatus = status + Flow.ResultFlow.empty() + } as BiFunction) + ig.registerCallback(EVENTS.responseHeader(), { RequestContext ctx, String name, String value -> + capturedResponseHeaders[name] = value + } as TriConsumer) + ig.registerCallback(EVENTS.responseHeaderDone(), { RequestContext ctx -> + responseHeaderDoneCalled = true + Flow.ResultFlow.empty() + } as Function) + ig.registerCallback(EVENTS.responseBody(), { RequestContext ctx, Object body -> + capturedResponseBody = body + Flow.ResultFlow.empty() + } as BiFunction) } def cleanup() { @@ -256,6 +285,146 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase !appSecStarted capturedMethod == null !appSecEnded + capturedResponseStatus == null + assertTraces(1) { + trace(1) { + span { + operationName operation() + spanType DDSpanTypes.SERVERLESS + errored false + } + } + } + } + + def "response callbacks are invoked for API Gateway v1 response format"() { + given: + def eventJson = """{ + "path": "/api/test", + "headers": {"content-type": "application/json"}, + "requestContext": { + "httpMethod": "GET", + "requestId": "req-resp-1", + "identity": {"sourceIp": "127.0.0.1"} + } + }""" + + when: + def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) + def output = new ByteArrayOutputStream() + def ctx = Stub(Context) { getAwsRequestId() >> requestId } + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, ctx) + + then: + capturedResponseStatus == 200 + capturedResponseHeaders["content-type"] == "application/json" + capturedResponseHeaders["x-custom"] == "custom-val" + capturedResponseBody instanceof Map + capturedResponseBody["result"] == "ok" + responseHeaderDoneCalled + appSecEnded + assertTraces(1) { + trace(1) { + span { + operationName operation() + spanType DDSpanTypes.SERVERLESS + errored false + } + } + } + } + + def "response callbacks receive correct data for 404 response"() { + given: + def eventJson = """{ + "path": "/missing", + "requestContext": { + "httpMethod": "GET", + "requestId": "req-resp-2" + } + }""" + + when: + def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) + def output = new ByteArrayOutputStream() + def ctx = Stub(Context) { getAwsRequestId() >> requestId } + new HandlerStreamingWith404Response().handleRequest(input, output, ctx) + + then: + capturedResponseStatus == 404 + capturedResponseHeaders["content-type"] == "text/html" + capturedResponseBody == "Not Found" // text/html body passed as raw string + appSecEnded + assertTraces(1) { + trace(1) { + span { + operationName operation() + spanType DDSpanTypes.SERVERLESS + errored false + } + } + } + } + + def "response callbacks handle non-API-Gateway response gracefully"() { + when: + def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) + def output = new ByteArrayOutputStream() + def ctx = Stub(Context) { getAwsRequestId() >> requestId } + new HandlerStreaming().handleRequest(input, output, ctx) + + then: + // HandlerStreaming writes "Hello World!" which is not valid API GW JSON + // Response parsing should fail gracefully + capturedResponseStatus == null + capturedResponseHeaders.isEmpty() + capturedResponseBody == null + // requestEnded should still be called + appSecEnded + assertTraces(1) { + trace(1) { + span { + operationName operation() + spanType DDSpanTypes.SERVERLESS + errored false + } + } + } + } + + def "response and request callbacks are both invoked in correct order"() { + given: + def eventJson = """{ + "path": "/api/users/123", + "headers": {"content-type": "application/json"}, + "body": "{\\"key\\": \\"value\\"}", + "requestContext": { + "httpMethod": "POST", + "requestId": "req-order-1", + "identity": {"sourceIp": "10.0.0.1"} + } + }""" + + when: + def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) + def output = new ByteArrayOutputStream() + def ctx = Stub(Context) { getAwsRequestId() >> requestId } + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, ctx) + + then: + // Request callbacks fired + appSecStarted + capturedMethod == "POST" + capturedPath == "/api/users/123" + capturedBody instanceof Map + + // Response callbacks fired + capturedResponseStatus == 200 + capturedResponseHeaders["content-type"] == "application/json" + capturedResponseBody instanceof Map + + // requestEnded fired last + appSecEnded assertTraces(1) { trace(1) { span { diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 853be632b4e..38845471f60 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -4,6 +4,7 @@ import com.squareup.moshi.JsonAdapter; import com.squareup.moshi.Moshi; +import datadog.logging.RatelimitedLogger; import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; @@ -24,16 +25,12 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.InputStreamReader; -import java.io.Reader; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Base64; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -47,6 +44,7 @@ public class LambdaAppSecHandler { private static final Logger log = LoggerFactory.getLogger(LambdaAppSecHandler.class); + private static final RatelimitedLogger rlLog = new RatelimitedLogger(log, 5, TimeUnit.MINUTES); private static final Moshi MOSHI = new Moshi.Builder().build(); private static final JsonAdapter MAP_ADAPTER = MOSHI.adapter(Map.class); @@ -54,10 +52,6 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); - private static final Set RESPONSE_HEADER_ALLOW_LIST = - new HashSet<>( - Arrays.asList("content-length", "content-type", "content-encoding", "content-language")); - /** * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes * all relevant AppSec gateway callbacks. @@ -189,7 +183,7 @@ public static void processResponseData(AgentSpan span, Object result) { } } } catch (Exception e) { - log.error("Failed to process AppSec response data", e); + log.debug("Failed to process AppSec response data", e); } } @@ -207,9 +201,8 @@ static LambdaResponseData extractResponseData(String json) { statusCode = ((Number) statusCodeObj).intValue(); } - // Extract and filter headers - Map headers = new java.util.HashMap<>(); - Map rawHeaders = extractStringMap(response.get("headers")); + // Extract headers + Map headers = extractStringMap(response.get("headers")); // Merge multiValueHeaders if present (API GW v1 / ALB) Object multiValueHeadersObj = response.get("multiValueHeaders"); @@ -223,18 +216,11 @@ static LambdaResponseData extractResponseData(String json) { values.stream() .map(String::valueOf) .collect(java.util.stream.Collectors.joining(", ")); - rawHeaders.put(key, joinedValue); + headers.put(key, joinedValue); } } } - // Filter to allow-list (case-insensitive) - for (Map.Entry entry : rawHeaders.entrySet()) { - if (RESPONSE_HEADER_ALLOW_LIST.contains(entry.getKey().toLowerCase())) { - headers.put(entry.getKey(), entry.getValue()); - } - } - // Extract body Object body = null; Object bodyObj = response.get("body"); @@ -309,7 +295,7 @@ public static AgentSpanContext mergeContexts( return merged; } - log.debug( + rlLog.warn( "Cannot merge AppSec data: extension context is not a TagContext: {}", extensionContext.getClass()); } @@ -436,15 +422,12 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream return LambdaEventData.EMPTY; } - StringBuilder jsonBuilder = new StringBuilder(availableBytes); - try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { - char[] buffer = new char[1024]; - int charsRead; - while ((charsRead = reader.read(buffer)) != -1) { - jsonBuilder.append(buffer, 0, charsRead); - } + byte[] bytes = new byte[availableBytes]; + int read = inputStream.read(bytes); + if (read <= 0) { + return LambdaEventData.EMPTY; } - return extractEventDataFromJson(jsonBuilder.toString()); + return extractEventDataFromJson(new String(bytes, 0, read, StandardCharsets.UTF_8)); } finally { inputStream.reset(); } @@ -678,7 +661,11 @@ private static LambdaEventData extractAlbData( String method = (String) event.get("httpMethod"); String path = (String) event.get("path"); String xff = headers.get("x-forwarded-for"); - String sourceIp = xff != null ? xff.split(",")[0].trim() : null; + String sourceIp = null; + if (xff != null) { + int commaIdx = xff.indexOf(','); + sourceIp = (commaIdx >= 0 ? xff.substring(0, commaIdx) : xff).trim(); + } return new LambdaEventData( headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index 750cfe4ab78..09dc410253c 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -1437,6 +1437,620 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { AgentTracer.forceRegister(mockTracer) } + // ============================================================================ + // processResponseData Tests + // ============================================================================ + + def "processResponseData does nothing when AppSec is disabled"() { + given: + ActiveSubsystems.APPSEC_ACTIVE = false + def span = Mock(AgentSpan) + def result = createOutputStream('{"statusCode": 200, "body": "ok"}') + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + 0 * span._ + } + + def "processResponseData does nothing for null span"() { + given: + def result = createOutputStream('{"statusCode": 200}') + + when: + LambdaAppSecHandler.processResponseData(null, result) + + then: + noExceptionThrown() + } + + def "processResponseData does nothing for non-ByteArrayOutputStream result"() { + given: + def span = Mock(AgentSpan) + + when: + LambdaAppSecHandler.processResponseData(span, "string result") + + then: + 0 * span._ + } + + def "processResponseData does nothing for null result"() { + given: + def span = Mock(AgentSpan) + + when: + LambdaAppSecHandler.processResponseData(span, null) + + then: + 0 * span._ + } + + def "processResponseData does nothing when span has no RequestContext"() { + given: + def span = Mock(AgentSpan) { + getRequestContext() >> null + } + def result = createOutputStream('{"statusCode": 200}') + + setupMockResponseCallbacks([:]) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + noExceptionThrown() + } + + def "processResponseData does nothing for oversized response"() { + given: + def maxSize = Config.get().getAppSecBodyParsingSizeLimit() + def largeBody = "x" * (maxSize + 1) + def result = createOutputStream(largeBody) + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedStatus == null + } + + def "processResponseData does nothing for empty ByteArrayOutputStream"() { + given: + def result = new ByteArrayOutputStream() // 0 bytes + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedStatus == null + } + + // --- Status code extraction --- + + def "processResponseData extracts statusCode correctly"() { + given: + def result = createOutputStream('{"statusCode": 200, "body": "ok"}') + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedStatus == 200 + } + + def "processResponseData extracts statusCode as integer from double"() { + given: + def result = createOutputStream('{"statusCode": 404.0, "body": "not found"}') + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedStatus == 404 + } + + def "processResponseData handles missing statusCode"() { + given: + def result = createOutputStream('{"body": "ok"}') + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedStatus == null // statusCode is 0, so responseStarted is not called + } + + def "processResponseData handles non-numeric statusCode"() { + given: + def result = createOutputStream('{"statusCode": "bad", "body": "ok"}') + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + noExceptionThrown() + capturedStatus == null // "bad" is not a Number, statusCode stays 0 + } + + // --- Header extraction --- + + def "processResponseData forwards all response headers"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "application/json", "x-custom": "val", "content-length": "42", "set-cookie": "a=1"}}' + def result = createOutputStream(json) + def capturedHeaders = [:] + + def span = setupMockResponseCallbacks( + onResponseHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedHeaders.size() == 4 + capturedHeaders["content-type"] == "application/json" + capturedHeaders["x-custom"] == "val" + capturedHeaders["content-length"] == "42" + capturedHeaders["set-cookie"] == "a=1" + } + + def "processResponseData preserves original header casing"() { + given: + def json = '{"statusCode": 200, "headers": {"Content-Type": "text/html", "CONTENT-LENGTH": "10"}}' + def result = createOutputStream(json) + def capturedHeaders = [:] + + def span = setupMockResponseCallbacks( + onResponseHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedHeaders["Content-Type"] == "text/html" + capturedHeaders["CONTENT-LENGTH"] == "10" + } + + def "processResponseData merges multiValueHeaders with single-value headers"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "text/html"}, "multiValueHeaders": {"content-encoding": ["gzip", "br"]}}' + def result = createOutputStream(json) + def capturedHeaders = [:] + + def span = setupMockResponseCallbacks( + onResponseHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedHeaders["content-type"] == "text/html" + capturedHeaders["content-encoding"] == "gzip, br" + } + + def "processResponseData handles empty headers"() { + given: + def result = createOutputStream('{"statusCode": 200}') + def capturedHeaders = [:] + def headerDoneCalled = false + + def span = setupMockResponseCallbacks( + onResponseHeader: { name, value -> capturedHeaders[name] = value }, + onResponseHeaderDone: { + headerDoneCalled = true + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedHeaders.isEmpty() + headerDoneCalled + } + + // --- Body extraction --- + + def "processResponseData parses JSON body"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "application/json"}, "body": "{\\"key\\": \\"value\\"}"}' + def result = createOutputStream(json) + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody instanceof Map + capturedBody["key"] == "value" + } + + def "processResponseData handles non-JSON body as raw string"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "text/plain"}, "body": "plain text"}' + def result = createOutputStream(json) + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody == "plain text" + } + + def "processResponseData handles base64 encoded body"() { + given: + def originalBody = '{"decoded": "content"}' + def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes(StandardCharsets.UTF_8)) + def json = """{"statusCode": 200, "body": "${base64Body}", "isBase64Encoded": true}""" + def result = createOutputStream(json) + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody instanceof Map + capturedBody["decoded"] == "content" + } + + def "processResponseData handles null body"() { + given: + def result = createOutputStream('{"statusCode": 200, "body": null}') + def capturedBody = "NOT_CALLED" + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody == "NOT_CALLED" + } + + def "processResponseData handles missing body field"() { + given: + def result = createOutputStream('{"statusCode": 200}') + def capturedBody = "NOT_CALLED" + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody == "NOT_CALLED" + } + + def "processResponseData attempts JSON parse when no content-type"() { + given: + def result = createOutputStream('{"statusCode": 200, "body": "{\\"a\\": 1}"}') + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody instanceof Map + capturedBody["a"] == 1.0d // Moshi parses numbers as Double + } + + def "processResponseData falls back to raw string when JSON parse fails"() { + given: + def result = createOutputStream('{"statusCode": 200, "body": "not json {"}') + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody == "not json {" + } + + // --- Event ordering --- + + def "processResponseData fires events in correct order"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "application/json"}, "body": "{\\"k\\": \\"v\\"}"}' + def result = createOutputStream(json) + def order = [] + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> order << "responseStarted" }, + onResponseHeader: { name, value -> order << "responseHeader" }, + onResponseHeaderDone: { order << "responseHeaderDone" }, + onResponseBody: { body -> + order << "responseBody" + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + order[0] == "responseStarted" + order.findAll { it == "responseHeader" }.size() >= 1 + def headerDoneIdx = order.indexOf("responseHeaderDone") + def lastHeaderIdx = order.lastIndexOf("responseHeader") + headerDoneIdx > lastHeaderIdx + order.last() == "responseBody" + } + + def "processResponseData handles invalid base64 response body gracefully"() { + given: + def json = '{"statusCode": 200, "body": "not-valid-base64!!!", "isBase64Encoded": true}' + def result = createOutputStream(json) + def capturedBody = "NOT_CALLED" + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + noExceptionThrown() + capturedBody == "NOT_CALLED" + } + + def "processResponseData parses body as JSON for javascript content-type"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "application/javascript"}, "body": "{\\"key\\": \\"val\\"}"}' + def result = createOutputStream(json) + def capturedBody = null + + def span = setupMockResponseCallbacks( + onResponseBody: { body -> + capturedBody = body + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedBody instanceof Map + capturedBody["key"] == "val" + } + + def "processResponseData multiValueHeaders override single-value headers"() { + given: + def json = '{"statusCode": 200, "headers": {"content-type": "text/html"}, "multiValueHeaders": {"content-type": ["application/json", "charset=utf-8"]}}' + def result = createOutputStream(json) + def capturedHeaders = [:] + + def span = setupMockResponseCallbacks( + onResponseHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + capturedHeaders["content-type"] == "application/json, charset=utf-8" + } + + // --- Error handling --- + + def "processResponseData handles malformed JSON response"() { + given: + def result = createOutputStream('{not valid json') + def capturedStatus = null + + def span = setupMockResponseCallbacks( + onResponseStarted: { status -> + capturedStatus = status + } + ) + + when: + LambdaAppSecHandler.processResponseData(span, result) + + then: + noExceptionThrown() + capturedStatus == null + } + + def "processResponseData handles empty string response"() { + given: + // Empty string inside a non-empty BAOS (the string "") + // This would fail JSON parsing + def result = createOutputStream('') + + when: + LambdaAppSecHandler.processResponseData(Mock(AgentSpan), result) + + then: + noExceptionThrown() + } + + // ============================================================================ + // extractResponseData Unit Tests + // ============================================================================ + + def "extractResponseData returns null for malformed JSON"() { + when: + def result = LambdaAppSecHandler.extractResponseData('{bad json') + + then: + result == null + } + + def "extractResponseData returns null for null JSON parse result"() { + when: + def result = LambdaAppSecHandler.extractResponseData('null') + + then: + result == null + } + + // ============================================================================ + // Helper Methods + // ============================================================================ + + private ByteArrayOutputStream createOutputStream(String json) { + def baos = new ByteArrayOutputStream() + baos.write(json.getBytes(StandardCharsets.UTF_8)) + return baos + } + + /** + * Set up mock response callbacks and return a mock span with a valid RequestContext. + * processResponseData uses span.getRequestContext() to get the RequestContext, + * unlike processRequestStart which uses TemporaryRequestContext. + */ + private AgentSpan setupMockResponseCallbacks(Map callbacks) { + def mockAppSecContext = new Object() + def mockRequestContext = Mock(RequestContext) { + getData(RequestContextSlot.APPSEC) >> mockAppSecContext + } + def mockSpan = Mock(AgentSpan) { + getRequestContext() >> mockRequestContext + } + + def mockResponseStartedCb = callbacks.onResponseStarted ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Integer) >> { + RequestContext ctx, Integer status -> + callbacks.onResponseStarted(status) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockResponseHeaderCb = callbacks.onResponseHeader ? Mock(TriConsumer) { + accept(_ as RequestContext, _ as String, _ as String) >> { + RequestContext ctx, String name, String value -> + callbacks.onResponseHeader(name, value) + } + } : null + + def mockResponseHeaderDoneCb = callbacks.onResponseHeaderDone ? Mock(Function) { + apply(_ as RequestContext) >> { + callbacks.onResponseHeaderDone() + return new Flow.ResultFlow<>(null) + } + } : Mock(Function) { + apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) + } + + def mockResponseBodyCb = callbacks.onResponseBody ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Object) >> { + RequestContext ctx, Object body -> + callbacks.onResponseBody(body) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.responseStarted()) >> mockResponseStartedCb + getCallback(EVENTS.responseHeader()) >> mockResponseHeaderCb + getCallback(EVENTS.responseHeaderDone()) >> mockResponseHeaderDoneCb + getCallback(EVENTS.responseBody()) >> mockResponseBodyCb + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + return mockSpan + } + def cleanup() { // Reset tracer after each test AgentTracer.forceRegister(null) From 65bb8442cba80267b961399869c8337b9e3888c3 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 26 Jun 2026 10:07:34 +0200 Subject: [PATCH 11/15] AppSec Lambda: HTTP trigger type detection + resource name fix for extension dedup Co-Authored-By: Claude Sonnet 4.6 --- .../lambda/LambdaHandlerInstrumentation.java | 7 + .../LambdaHandlerInstrumentationTest.groovy | 463 ---- .../{groovy => java}/HandlerStreaming.java | 0 ...reamingSimulatesHttpFrameworkResource.java | 0 .../HandlerStreamingWith404Response.java | 0 .../HandlerStreamingWithApiGwResponse.java | 0 .../HandlerStreamingWithError.java | 0 .../java/HandlerStreamingWithRawJson.java | 15 + .../LambdaHandlerInstrumentationTest.java | 576 +++++ .../LambdaHandlerInstrumentationV0Test.java | 15 + ...bdaHandlerInstrumentationV1ForkedTest.java | 15 + .../trace/lambda/LambdaAppSecHandler.java | 106 +- .../lambda/LambdaAppSecHandlerTest.groovy | 2058 ----------------- .../trace/lambda/LambdaAppSecHandlerTest.java | 1672 ++++++++----- 14 files changed, 1767 insertions(+), 3160 deletions(-) delete mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy rename dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/{groovy => java}/HandlerStreaming.java (100%) rename dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/{groovy => java}/HandlerStreamingSimulatesHttpFrameworkResource.java (100%) rename dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/{groovy => java}/HandlerStreamingWith404Response.java (100%) rename dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/{groovy => java}/HandlerStreamingWithApiGwResponse.java (100%) rename dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/{groovy => java}/HandlerStreamingWithError.java (100%) create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java create mode 100644 dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java delete mode 100644 dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index f2f1e31d06c..2c1370a3a80 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -24,6 +24,7 @@ import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; import datadog.trace.bootstrap.instrumentation.api.AgentTracer; import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; +import datadog.trace.bootstrap.instrumentation.api.ResourceNamePriorities; import datadog.trace.config.inversion.ConfigHelper; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.type.TypeDescription; @@ -126,6 +127,12 @@ static void exit( String lambdaRequestId = awsContext.getAwsRequestId(); AgentTracer.get().notifyAppSecEnd(span, result); + // Force the resource name back to the literal placeholder marker right + // before finish so that the Datadog Lambda Extension's filter + // (filter_span_from_lambda_library_or_runtime in + // bottlecap/src/traces/trace_processor.rs, which compares + // span.resource == "dd-tracer-serverless-span") drops the placeholder. + span.setResourceName(INVOCATION_SPAN_NAME, ResourceNamePriorities.TAG_INTERCEPTOR); span.finish(); AgentTracer.get().notifyExtensionEnd(span, result, null != throwable, lambdaRequestId); } finally { diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy deleted file mode 100644 index 51fbf25c62c..00000000000 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy +++ /dev/null @@ -1,463 +0,0 @@ -import static datadog.trace.api.gateway.Events.EVENTS - -import datadog.trace.agent.test.naming.VersionedNamingTestBase -import datadog.trace.api.DDSpanTypes -import datadog.trace.api.function.TriConsumer -import datadog.trace.api.function.TriFunction -import datadog.trace.api.gateway.Flow -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.gateway.RequestContextSlot -import datadog.trace.bootstrap.ActiveSubsystems -import datadog.trace.bootstrap.instrumentation.api.AgentTracer -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter -import com.amazonaws.services.lambda.runtime.Context -import java.nio.charset.StandardCharsets -import java.util.function.BiFunction -import java.util.function.Function -import java.util.function.Supplier - -abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase { - def requestId = "test-request-id" - - // Must set this env var before the Datadog integration is initialized. - // If present at load time, the integration auto-enables. - static { - environmentVariables.set("_HANDLER", "Handler") - } - - @Override - String service() { - null - } - - def ig - def appSecStarted = false - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedBody = null - def appSecEnded = false - - // Response callback capture fields - def capturedResponseStatus = null - def capturedResponseHeaders = [:] - def capturedResponseBody = null - def responseHeaderDoneCalled = false - - def setup() { - ig = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC) - ActiveSubsystems.APPSEC_ACTIVE = true - appSecStarted = false - capturedMethod = null - capturedPath = null - capturedHeaders = [:] - capturedBody = null - appSecEnded = false - capturedResponseStatus = null - capturedResponseHeaders = [:] - capturedResponseBody = null - responseHeaderDoneCalled = false - - // Request callbacks - ig.registerCallback(EVENTS.requestStarted(), { - appSecStarted = true - new Flow.ResultFlow(new Object()) - } as Supplier) - ig.registerCallback(EVENTS.requestMethodUriRaw(), { RequestContext ctx, String method, URIDataAdapter uri -> - capturedMethod = method - capturedPath = uri.path() - Flow.ResultFlow.empty() - } as TriFunction) - ig.registerCallback(EVENTS.requestHeader(), { RequestContext ctx, String name, String value -> - capturedHeaders[name] = value - } as TriConsumer) - ig.registerCallback(EVENTS.requestHeaderDone(), { RequestContext ctx -> - Flow.ResultFlow.empty() - } as Function) - ig.registerCallback(EVENTS.requestBodyProcessed(), { RequestContext ctx, Object body -> - capturedBody = body - Flow.ResultFlow.empty() - } as BiFunction) - ig.registerCallback(EVENTS.requestEnded(), { RequestContext ctx, Object spanInfo -> - appSecEnded = true - Flow.ResultFlow.empty() - } as BiFunction) - - // Response callbacks - ig.registerCallback(EVENTS.responseStarted(), { RequestContext ctx, Integer status -> - capturedResponseStatus = status - Flow.ResultFlow.empty() - } as BiFunction) - ig.registerCallback(EVENTS.responseHeader(), { RequestContext ctx, String name, String value -> - capturedResponseHeaders[name] = value - } as TriConsumer) - ig.registerCallback(EVENTS.responseHeaderDone(), { RequestContext ctx -> - responseHeaderDoneCalled = true - Flow.ResultFlow.empty() - } as Function) - ig.registerCallback(EVENTS.responseBody(), { RequestContext ctx, Object body -> - capturedResponseBody = body - Flow.ResultFlow.empty() - } as BiFunction) - } - - def cleanup() { - ig.reset() - ActiveSubsystems.APPSEC_ACTIVE = false - } - - def "test lambda streaming handler"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "serverless invocation span resource reset after simulated HTTP framework overwrite"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreamingSimulatesHttpFrameworkResource().handleRequest(input, output, ctx) - - then: - assertTraces(1) { - trace(1) { - span { - operationName operation() - resourceName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "test streaming handler with error"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreamingWithError().handleRequest(input, output, ctx) - - then: - thrown(Error) - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored true - tags { - tag "request_id", requestId - tag "error.type", "java.lang.Error" - tag "error.message", "Some error" - tag "error.stack", String - tag "language", "jvm" - tag "process_id", Long - tag "runtime-id", String - tag "thread.id", Long - tag "thread.name", String - tag "_dd.profiling.ctx", "test" - tag "_dd.profiling.enabled", 0 - tag "_dd.agent_psr", 1.0 - tag "_dd.tracer_host", String - tag "_sample_rate", 1 - tag "_dd.trace_span_attribute_schema", { it != null } - } - } - } - } - } - - def "appsec callbacks are invoked for API Gateway v1 event"() { - given: - def eventJson = """{ - "path": "/api/users/123", - "headers": {"content-type": "application/json", "x-forwarded-for": "203.0.113.1"}, - "body": "{\\"key\\": \\"value\\"}", - "requestContext": { - "httpMethod": "GET", - "requestId": "req-abc", - "identity": {"sourceIp": "203.0.113.1"} - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - appSecStarted - capturedMethod == "GET" - capturedPath == "/api/users/123" - capturedHeaders["content-type"] == "application/json" - capturedBody instanceof Map - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "appsec callbacks are invoked for API Gateway v2 HTTP event"() { - given: - def eventJson = """{ - "version": "2.0", - "headers": {"content-type": "application/json", "accept": "application/json"}, - "cookies": ["session=abc123"], - "body": "{\\"key\\": \\"value\\"}", - "requestContext": { - "http": { - "method": "POST", - "path": "/api/items", - "sourceIp": "198.51.100.1" - }, - "domainName": "api.example.com" - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - appSecStarted - capturedMethod == "POST" - capturedPath == "/api/items" - capturedHeaders["content-type"] == "application/json" - capturedHeaders["cookie"] == "session=abc123" - capturedBody instanceof Map - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "appsec callbacks are not invoked when appsec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - - when: - def eventJson = """{ - "path": "/api/test", - "requestContext": {"httpMethod": "GET", "requestId": "req-xyz"} - }""" - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - !appSecStarted - capturedMethod == null - !appSecEnded - capturedResponseStatus == null - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "response callbacks are invoked for API Gateway v1 response format"() { - given: - def eventJson = """{ - "path": "/api/test", - "headers": {"content-type": "application/json"}, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-resp-1", - "identity": {"sourceIp": "127.0.0.1"} - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreamingWithApiGwResponse().handleRequest(input, output, ctx) - - then: - capturedResponseStatus == 200 - capturedResponseHeaders["content-type"] == "application/json" - capturedResponseHeaders["x-custom"] == "custom-val" - capturedResponseBody instanceof Map - capturedResponseBody["result"] == "ok" - responseHeaderDoneCalled - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "response callbacks receive correct data for 404 response"() { - given: - def eventJson = """{ - "path": "/missing", - "requestContext": { - "httpMethod": "GET", - "requestId": "req-resp-2" - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreamingWith404Response().handleRequest(input, output, ctx) - - then: - capturedResponseStatus == 404 - capturedResponseHeaders["content-type"] == "text/html" - capturedResponseBody == "Not Found" // text/html body passed as raw string - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "response callbacks handle non-API-Gateway response gracefully"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - // HandlerStreaming writes "Hello World!" which is not valid API GW JSON - // Response parsing should fail gracefully - capturedResponseStatus == null - capturedResponseHeaders.isEmpty() - capturedResponseBody == null - // requestEnded should still be called - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "response and request callbacks are both invoked in correct order"() { - given: - def eventJson = """{ - "path": "/api/users/123", - "headers": {"content-type": "application/json"}, - "body": "{\\"key\\": \\"value\\"}", - "requestContext": { - "httpMethod": "POST", - "requestId": "req-order-1", - "identity": {"sourceIp": "10.0.0.1"} - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreamingWithApiGwResponse().handleRequest(input, output, ctx) - - then: - // Request callbacks fired - appSecStarted - capturedMethod == "POST" - capturedPath == "/api/users/123" - capturedBody instanceof Map - - // Response callbacks fired - capturedResponseStatus == 200 - capturedResponseHeaders["content-type"] == "application/json" - capturedResponseBody instanceof Map - - // requestEnded fired last - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } -} - - -class LambdaHandlerInstrumentationV0Test extends LambdaHandlerInstrumentationTest { - @Override - int version() { - 0 - } - - @Override - String operation() { - "dd-tracer-serverless-span" - } -} - -class LambdaHandlerInstrumentationV1ForkedTest extends LambdaHandlerInstrumentationTest { - @Override - int version() { - 1 - } - - @Override - String operation() { - "aws.lambda.invoke" - } -} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreaming.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreaming.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreaming.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreaming.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingSimulatesHttpFrameworkResource.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingSimulatesHttpFrameworkResource.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingSimulatesHttpFrameworkResource.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingSimulatesHttpFrameworkResource.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWith404Response.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWith404Response.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWith404Response.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithApiGwResponse.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithApiGwResponse.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithApiGwResponse.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithError.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithError.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithError.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithError.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java new file mode 100644 index 00000000000..a34a2623770 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java @@ -0,0 +1,15 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; + +/** Writes valid JSON that is not in API Gateway response format (no statusCode/headers/body). */ +public class HandlerStreamingWithRawJson implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + outputStream.write("{\"result\": \"hello\"}".getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java new file mode 100644 index 00000000000..9ab0606f2d2 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java @@ -0,0 +1,576 @@ +import static datadog.trace.agent.test.assertions.Matchers.is; +import static datadog.trace.agent.test.assertions.SpanMatcher.span; +import static datadog.trace.agent.test.assertions.TagsMatcher.defaultTags; +import static datadog.trace.agent.test.assertions.TagsMatcher.error; +import static datadog.trace.agent.test.assertions.TagsMatcher.tag; +import static datadog.trace.agent.test.assertions.TraceMatcher.trace; +import static datadog.trace.api.gateway.Events.EVENTS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazonaws.services.lambda.runtime.ClientContext; +import com.amazonaws.services.lambda.runtime.CognitoIdentity; +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.LambdaLogger; +import datadog.trace.agent.test.AbstractInstrumentationTest; +import datadog.trace.api.DDSpanTypes; +import datadog.trace.api.function.TriConsumer; +import datadog.trace.api.function.TriFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.IGSpanInfo; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.gateway.SubscriptionService; +import datadog.trace.bootstrap.ActiveSubsystems; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; +import datadog.trace.junit.utils.config.WithConfig; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@WithConfig(key = "_HANDLER", value = "Handler", env = true, addPrefix = false) +abstract class LambdaHandlerInstrumentationTest extends AbstractInstrumentationTest { + + static final String REQUEST_ID = "test-request-id"; + + // Object to avoid bootstrap class in field type (TestClassShadowingExtension check) + Object ig; + + boolean appSecStarted; + String capturedMethod; + String capturedPath; + Map capturedHeaders; + Object capturedBody; + boolean appSecEnded; + + Integer capturedResponseStatus; + Map capturedResponseHeaders; + Object capturedResponseBody; + boolean responseHeaderDoneCalled; + + abstract int version(); + + abstract String operation(); + + @BeforeEach + void setUpAppSec() { + SubscriptionService ss = + (SubscriptionService) AgentTracer.get().getSubscriptionService(RequestContextSlot.APPSEC); + ig = ss; + ActiveSubsystems.APPSEC_ACTIVE = true; + + appSecStarted = false; + capturedMethod = null; + capturedPath = null; + capturedHeaders = new HashMap<>(); + capturedBody = null; + appSecEnded = false; + capturedResponseStatus = null; + capturedResponseHeaders = new HashMap<>(); + capturedResponseBody = null; + responseHeaderDoneCalled = false; + + ss.registerCallback( + EVENTS.requestStarted(), + (Supplier>) + () -> { + appSecStarted = true; + return new Flow.ResultFlow<>(new Object()); + }); + ss.registerCallback( + EVENTS.requestMethodUriRaw(), + (TriFunction>) + (ctx2, method2, uri) -> { + capturedMethod = method2; + capturedPath = uri.path(); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestHeader(), + (TriConsumer) + (ctx2, name, value) -> capturedHeaders.put(name, value)); + ss.registerCallback( + EVENTS.requestHeaderDone(), + (Function>) ctx2 -> Flow.ResultFlow.empty()); + ss.registerCallback( + EVENTS.requestBodyProcessed(), + (BiFunction>) + (ctx2, body) -> { + capturedBody = body; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestEnded(), + (BiFunction>) + (ctx2, spanInfo) -> { + appSecEnded = true; + return Flow.ResultFlow.empty(); + }); + + ss.registerCallback( + EVENTS.responseStarted(), + (BiFunction>) + (ctx2, status) -> { + capturedResponseStatus = status; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseHeader(), + (TriConsumer) + (ctx2, name, value) -> capturedResponseHeaders.put(name, value)); + ss.registerCallback( + EVENTS.responseHeaderDone(), + (Function>) + ctx2 -> { + responseHeaderDoneCalled = true; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseBody(), + (BiFunction>) + (ctx2, body) -> { + capturedResponseBody = body; + return Flow.ResultFlow.empty(); + }); + } + + @AfterEach + void cleanUpAppSec() { + ((SubscriptionService) ig).reset(); + ActiveSubsystems.APPSEC_ACTIVE = false; + } + + private Context newContext() { + return new TestContext(REQUEST_ID); + } + + @Test + void testLambdaStreamingHandler() throws IOException { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void serverlessInvocationSpanResourceResetAfterHttpFrameworkOverwrite() throws IOException { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingSimulatesHttpFrameworkResource().handleRequest(input, output, newContext()); + + assertTraces( + trace( + span() + .resourceName(name -> operation().equals(name.toString())) + .type(DDSpanTypes.SERVERLESS) + .error(false))); + } + + @Test + void testStreamingHandlerWithError() { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + + assertThrows( + Error.class, + () -> new HandlerStreamingWithError().handleRequest(input, output, newContext())); + + assertTraces( + trace( + span() + .type(DDSpanTypes.SERVERLESS) + .error(true) + .tags( + defaultTags(), + tag("request_id", is(REQUEST_ID)), + error(Error.class, "Some error")))); + } + + @Test + void appSecCallbacksAreInvokedForApiGatewayV1Event() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/users/123\"," + + "\"headers\": {\"content-type\": \"application/json\"," + + " \"x-forwarded-for\": \"203.0.113.1\"}," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"requestId\": \"req-abc\"," + + " \"identity\": {\"sourceIp\": \"203.0.113.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("GET", capturedMethod); + assertEquals("/api/users/123", capturedPath); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertTrue(capturedBody instanceof Map); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void appSecCallbacksAreInvokedForApiGatewayV2HttpEvent() throws IOException { + String eventJson = + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"content-type\": \"application/json\"," + + " \"accept\": \"application/json\"}," + + "\"cookies\": [\"session=abc123\"]," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"http\": {" + + " \"method\": \"POST\"," + + " \"path\": \"/api/items\"," + + " \"sourceIp\": \"198.51.100.1\"" + + " }," + + " \"domainName\": \"api.example.com\"" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("POST", capturedMethod); + assertEquals("/api/items", capturedPath); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertEquals("session=abc123", capturedHeaders.get("cookie")); + assertTrue(capturedBody instanceof Map); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void appSecCallbacksAreNotInvokedWhenAppSecIsDisabled() throws IOException { + ActiveSubsystems.APPSEC_ACTIVE = false; + + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-xyz\"}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertFalse(appSecStarted); + assertNull(capturedMethod); + assertFalse(appSecEnded); + assertNull(capturedResponseStatus); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksAreInvokedForJsonEncodedResponse() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"identity\": {\"sourceIp\": \"127.0.0.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertEquals(200, (int) capturedResponseStatus); + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertEquals("custom-val", capturedResponseHeaders.get("x-custom")); + assertTrue(capturedResponseBody instanceof Map); + assertEquals("ok", ((Map) capturedResponseBody).get("result")); + assertTrue(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksReceiveCorrectDataFor404Response() throws IOException { + String eventJson = + "{" + + "\"path\": \"/missing\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWith404Response().handleRequest(input, output, newContext()); + + assertEquals(404, (int) capturedResponseStatus); + assertEquals("text/html", capturedResponseHeaders.get("content-type")); + assertEquals("Not Found", capturedResponseBody); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksApplyFallbackForLambdaUrlWithNonApiGatewayResponse() throws IOException { + // A Lambda Function URL handler returning plain JSON (no statusCode/headers/body structure) + // should trigger the fallback: no responseStarted (status unknown), content-type: + // application/json, full JSON as body. + String eventJson = + "{" + + "\"version\": \"2.0\"," + + "\"rawPath\": \"/\"," + + "\"headers\": {\"host\": \"example.lambda-url.us-east-1.on.aws\"}," + + "\"requestContext\": {" + + " \"domainName\": \"example.lambda-url.us-east-1.on.aws\"," + + " \"http\": {" + + " \"method\": \"GET\"," + + " \"path\": \"/\"," + + " \"sourceIp\": \"1.2.3.4\"" + + " }" + + "}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithRawJson().handleRequest(input, output, newContext()); + + assertNull(capturedResponseStatus); // no responseStarted for status-less fallback + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertTrue(capturedResponseBody instanceof Map); + assertEquals("hello", ((Map) capturedResponseBody).get("result")); + assertTrue(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksSkipNonApiGatewayResponseForNonHttpEvent() throws IOException { + // When the trigger type cannot be determined (non-JSON or non-HTTP event), + // response callbacks must not fire even if the response is valid JSON. + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithRawJson().handleRequest(input, output, newContext()); + + assertNull(capturedResponseStatus); + assertTrue(capturedResponseHeaders.isEmpty()); + assertNull(capturedResponseBody); + assertFalse(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseAndRequestCallbacksAreBothInvoked() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/users/123\"," + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"POST\"," + + " \"requestId\": \"req-order-1\"," + + " \"identity\": {\"sourceIp\": \"10.0.0.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("POST", capturedMethod); + assertEquals("/api/users/123", capturedPath); + assertTrue(capturedBody instanceof Map); + + assertEquals(200, (int) capturedResponseStatus); + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertTrue(capturedResponseBody instanceof Map); + + assertTrue(appSecEnded); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksFireBeforeRequestEnded() throws IOException { + List callOrder = new ArrayList<>(); + + // Reset and re-register to capture ordering + SubscriptionService ss = (SubscriptionService) ig; + ss.reset(); + + ss.registerCallback( + EVENTS.requestStarted(), + (Supplier>) () -> new Flow.ResultFlow<>(new Object())); + ss.registerCallback( + EVENTS.responseStarted(), + (BiFunction>) + (ctx2, status) -> { + callOrder.add("responseStarted"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseHeaderDone(), + (Function>) + ctx2 -> { + callOrder.add("responseHeaderDone"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseBody(), + (BiFunction>) + (ctx2, body) -> { + callOrder.add("responseBody"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestEnded(), + (BiFunction>) + (ctx2, spanInfo) -> { + callOrder.add("requestEnded"); + return Flow.ResultFlow.empty(); + }); + + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"requestId\": \"req-order-2\"" + + "}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertTrue(callOrder.contains("responseStarted")); + assertTrue(callOrder.contains("responseHeaderDone")); + assertTrue(callOrder.contains("responseBody")); + assertTrue(callOrder.contains("requestEnded")); + assertTrue(callOrder.indexOf("responseStarted") < callOrder.indexOf("requestEnded")); + assertTrue(callOrder.indexOf("responseHeaderDone") < callOrder.indexOf("requestEnded")); + assertTrue(callOrder.indexOf("responseBody") < callOrder.indexOf("requestEnded")); + assertTraces( + trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksReceiveNoDataWhenHandlerThrows() { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + + assertThrows( + Error.class, + () -> new HandlerStreamingWithError().handleRequest(input, output, newContext())); + + assertNull(capturedResponseStatus, "response status should not be set when handler throws"); + assertNull(capturedResponseBody, "response body should not be set when handler throws"); + assertTraces( + trace( + span() + .type(DDSpanTypes.SERVERLESS) + .error(true) + .tags( + defaultTags(), + tag("request_id", is(REQUEST_ID)), + error(Error.class, "Some error")))); + } + + private static final class TestContext implements Context { + private final String requestId; + + TestContext(String requestId) { + this.requestId = requestId; + } + + @Override + public String getAwsRequestId() { + return requestId; + } + + @Override + public String getLogGroupName() { + return null; + } + + @Override + public String getLogStreamName() { + return null; + } + + @Override + public String getFunctionName() { + return null; + } + + @Override + public String getFunctionVersion() { + return null; + } + + @Override + public String getInvokedFunctionArn() { + return null; + } + + @Override + public CognitoIdentity getIdentity() { + return null; + } + + @Override + public ClientContext getClientContext() { + return null; + } + + @Override + public int getRemainingTimeInMillis() { + return 0; + } + + @Override + public int getMemoryLimitInMB() { + return 0; + } + + @Override + public LambdaLogger getLogger() { + return null; + } + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java new file mode 100644 index 00000000000..2c669643a7e --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java @@ -0,0 +1,15 @@ +import datadog.trace.junit.utils.config.WithConfig; + +@WithConfig(key = "trace.span.attribute.schema", value = "v0") +class LambdaHandlerInstrumentationV0Test extends LambdaHandlerInstrumentationTest { + + @Override + int version() { + return 0; + } + + @Override + String operation() { + return "dd-tracer-serverless-span"; + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java new file mode 100644 index 00000000000..c37027b64e2 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java @@ -0,0 +1,15 @@ +import datadog.trace.junit.utils.config.WithConfig; + +@WithConfig(key = "trace.span.attribute.schema", value = "v1") +class LambdaHandlerInstrumentationV1ForkedTest extends LambdaHandlerInstrumentationTest { + + @Override + int version() { + return 1; + } + + @Override + String operation() { + return "aws.lambda.invoke"; + } +} diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 38845471f60..1c0b1112df2 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -26,14 +26,18 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Base64; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,6 +56,10 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); + // Carries the detected trigger type from processRequestStart to processResponseData within the + // same Lambda invocation. Cleared in processRequestEnd. + private static final ThreadLocal CURRENT_TRIGGER_TYPE = new ThreadLocal<>(); + /** * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes * all relevant AppSec gateway callbacks. @@ -66,6 +74,8 @@ public static AgentSpanContext processRequestStart(Object event) { return null; } + CURRENT_TRIGGER_TYPE.set(LambdaTriggerType.UNKNOWN); + if (!(event instanceof ByteArrayInputStream)) { log.debug( "Event is not a ByteArrayInputStream, type: {}", @@ -78,6 +88,7 @@ public static AgentSpanContext processRequestStart(Object event) { if (eventData == LambdaEventData.EMPTY) { return null; } + CURRENT_TRIGGER_TYPE.set(eventData.triggerType); return processAppSecRequestData(eventData); } catch (Exception e) { log.debug("Failed to process AppSec request data", e); @@ -91,6 +102,8 @@ public static AgentSpanContext processRequestStart(Object event) { * @param span the current span */ public static void processRequestEnd(AgentSpan span) { + CURRENT_TRIGGER_TYPE.remove(); + if (!ActiveSubsystems.APPSEC_ACTIVE || span == null) { return; } @@ -134,8 +147,32 @@ public static void processResponseData(AgentSpan span, Object result) { String json = new String(bytes, StandardCharsets.UTF_8); LambdaResponseData responseData = extractResponseData(json); - if (responseData == null) { - return; + + if (responseData == null || responseData.statusCode == 0) { + // No statusCode means this is not an API-GW formatted response, or JSON parsing failed. + // Only process for known HTTP trigger types, mirroring Python's asm_start_response. + LambdaTriggerType triggerType = CURRENT_TRIGGER_TYPE.get(); + if (triggerType == null || !triggerType.isHttp()) { + return; + } + if (responseData == null || (responseData.headers.isEmpty() && responseData.body == null)) { + // Parse failed or response has no API-GW structure (plain JSON body). + // Treat the full response as the body, mirroring Python's asm_start_response. + Object fallbackBody; + String fallbackContentType; + try { + fallbackBody = OBJECT_ADAPTER.fromJson(json); + fallbackContentType = "application/json"; + } catch (Exception e) { + fallbackBody = json; + fallbackContentType = "text/plain"; + } + Map fallbackHeaders = + Collections.singletonMap("content-type", fallbackContentType); + responseData = new LambdaResponseData(0, fallbackHeaders, fallbackBody); + } + // else: responseData has explicit headers/body fields β€” keep them, just skip responseStarted + // (statusCode remains 0, so the responseStarted guard below will not fire). } RequestContext requestContext = span.getRequestContext(); @@ -147,6 +184,9 @@ public static void processResponseData(AgentSpan span, Object result) { AgentTracer.TracerAPI tracer = AgentTracer.get(); CallbackProvider cbp = tracer.getCallbackProvider(RequestContextSlot.APPSEC); + // Fire response gateway events. Flow results are intentionally ignored: blocking on response + // is not supported for Lambda because remote config is unavailable in that environment. + // Fire responseStarted if (responseData.statusCode > 0) { BiFunction> responseStartedCb = @@ -201,21 +241,23 @@ static LambdaResponseData extractResponseData(String json) { statusCode = ((Number) statusCodeObj).intValue(); } - // Extract headers - Map headers = extractStringMap(response.get("headers")); + // Extract headers β€” keys are lowercased to normalise casing across API GW / ALB variants + Map headers = new HashMap<>(); + Map rawHeaders = extractStringMap(response.get("headers")); + for (Map.Entry entry : rawHeaders.entrySet()) { + headers.put(entry.getKey().toLowerCase(Locale.ROOT), entry.getValue()); + } - // Merge multiValueHeaders if present (API GW v1 / ALB) + // Merge multiValueHeaders if present (API GW v1 / ALB), also lowercasing keys Object multiValueHeadersObj = response.get("multiValueHeaders"); if (multiValueHeadersObj instanceof Map) { Map multiValueHeaders = (Map) multiValueHeadersObj; for (Map.Entry entry : multiValueHeaders.entrySet()) { - if (entry.getKey() != null && entry.getValue() instanceof java.util.List) { - String key = String.valueOf(entry.getKey()); - java.util.List values = (java.util.List) entry.getValue(); + if (entry.getKey() != null && entry.getValue() instanceof List) { + String key = String.valueOf(entry.getKey()).toLowerCase(Locale.ROOT); + List values = (List) entry.getValue(); String joinedValue = - values.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining(", ")); + values.stream().map(String::valueOf).collect(Collectors.joining(", ")); headers.put(key, joinedValue); } } @@ -228,8 +270,8 @@ static LambdaResponseData extractResponseData(String json) { String bodyString = String.valueOf(bodyObj); // Handle base64 encoding - Boolean isBase64Encoded = (Boolean) response.get("isBase64Encoded"); - if (Boolean.TRUE.equals(isBase64Encoded)) { + Object isBase64EncodedObj = response.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64EncodedObj) || "true".equals(isBase64EncodedObj)) { try { bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); } catch (Exception e) { @@ -239,14 +281,7 @@ static LambdaResponseData extractResponseData(String json) { } if (bodyString != null) { - // Determine content-type from response headers - String contentType = null; - for (Map.Entry entry : headers.entrySet()) { - if ("content-type".equalsIgnoreCase(entry.getKey())) { - contentType = entry.getValue(); - break; - } - } + String contentType = headers.get("content-type"); // If JSON content-type or unknown, attempt JSON parsing if (contentType == null @@ -620,7 +655,7 @@ private static LambdaEventData extractAlbData( if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { // Handle multi-value headers (combine multiple values with comma) - headers = new java.util.HashMap<>(); + headers = new HashMap<>(); Object multiValueHeadersObj = event.get("multiValueHeaders"); if (multiValueHeadersObj instanceof Map) { Map rawHeaders = (Map) multiValueHeadersObj; @@ -739,7 +774,7 @@ private static LambdaEventData extractGenericData(Map event) { * values to strings, filtering out null entries. */ private static Map extractStringMap(Object mapObj) { - Map result = new java.util.HashMap<>(); + Map result = new HashMap<>(); if (mapObj instanceof Map) { Map rawMap = (Map) mapObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -775,7 +810,7 @@ private static Map extractPathParameters(Object pathParamsObj) { * Map> format expected by AppSec. */ private static Map> extractQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); + Map> result = new HashMap<>(); if (queryParamsObj instanceof Map) { Map rawMap = (Map) queryParamsObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -795,7 +830,7 @@ private static Map> extractQueryParameters(Object queryPara * List> format directly. */ private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); + Map> result = new HashMap<>(); if (queryParamsObj instanceof Map) { Map rawMap = (Map) queryParamsObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -803,7 +838,7 @@ private static Map> extractMultiValueQueryParameters(Object String key = String.valueOf(entry.getKey()); if (entry.getValue() instanceof java.util.List) { java.util.List values = (java.util.List) entry.getValue(); - java.util.List stringValues = new java.util.ArrayList<>(); + List stringValues = new ArrayList<>(); for (Object value : values) { if (value != null) { stringValues.add(String.valueOf(value)); @@ -891,8 +926,8 @@ private static Object extractBody(Map event) { String bodyString = String.valueOf(bodyObj); // Check if body is base64 encoded (API Gateway feature) - Boolean isBase64Encoded = (Boolean) event.get("isBase64Encoded"); - if (Boolean.TRUE.equals(isBase64Encoded)) { + Object isBase64EncodedObj = event.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64EncodedObj) || "true".equals(isBase64EncodedObj)) { try { bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); } catch (Exception e) { @@ -926,6 +961,15 @@ private static Object parseBodyAsJson(String body) { } } + /** Sets the current trigger type thread-local. Package-private for use in tests only. */ + static void setCurrentTriggerType(LambdaTriggerType type) { + if (type == null) { + CURRENT_TRIGGER_TYPE.remove(); + } else { + CURRENT_TRIGGER_TYPE.set(type); + } + } + /** * Temporary RequestContext implementation to hold AppSecRequestContext before a span is created. */ @@ -988,7 +1032,11 @@ enum LambdaTriggerType { ALB, // Application Load Balancer ALB_MULTI_VALUE, // ALB with multi-value headers LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger + UNKNOWN; // Unknown or unsupported trigger + + boolean isHttp() { + return this != UNKNOWN; + } } /** Object for Lambda event data needed for AppSec processing */ diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy deleted file mode 100644 index 09dc410253c..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ /dev/null @@ -1,2058 +0,0 @@ -package datadog.trace.lambda - -import datadog.trace.api.Config -import datadog.trace.api.function.TriConsumer -import datadog.trace.api.function.TriFunction -import datadog.trace.api.gateway.CallbackProvider -import datadog.trace.api.gateway.Flow -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.gateway.RequestContextSlot -import datadog.trace.bootstrap.ActiveSubsystems -import datadog.trace.bootstrap.instrumentation.api.AgentSpan -import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext -import datadog.trace.bootstrap.instrumentation.api.AgentTracer -import datadog.trace.bootstrap.instrumentation.api.TagContext -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter -import datadog.trace.core.test.DDCoreSpecification -import spock.lang.Shared - -import java.nio.charset.StandardCharsets -import java.util.function.BiFunction -import java.util.function.Function -import java.util.function.Supplier - -import static datadog.trace.api.gateway.Events.EVENTS - -class LambdaAppSecHandlerTest extends DDCoreSpecification { - - @Shared - def originalAppSecActive - - def setupSpec() { - originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE - } - - def cleanupSpec() { - ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive - } - - def setup() { - ActiveSubsystems.APPSEC_ACTIVE = true - } - - def "processRequestStart returns null when AppSec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - def event = createInputStream('{"test": "data"}') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for non-ByteArrayInputStream"() { - when: - def result = LambdaAppSecHandler.processRequestStart("not a stream") - - then: - result == null - } - - def "processRequestStart returns null for null event"() { - when: - def result = LambdaAppSecHandler.processRequestStart(null) - - then: - result == null - } - - def "processRequestStart returns null for oversized event"() { - given: - def maxSize = Config.get().getAppSecBodyParsingSizeLimit() - def largeBody = "x" * (maxSize + 1) - def event = createInputStream(largeBody) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for zero-size event"() { - given: - def event = createInputStream('') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for malformed JSON"() { - given: - def event = createInputStream('{invalid json') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "stream can be read multiple times after processing"() { - given: - def jsonData = '{"test": "data", "requestContext": {"httpMethod": "GET"}}' - def event = createInputStream(jsonData) - - when: - LambdaAppSecHandler.processRequestStart(event) - event.reset() - def content = new String(event.bytes, StandardCharsets.UTF_8) - - then: - content == jsonData - } - - - // ============================================================================ - // Trigger Type Detection Tests - // ============================================================================ - - def "detects API Gateway v1 REST trigger type"() { - given: - def event = [ - requestContext: [ - httpMethod: "GET", - requestId: "abc123" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST - } - - def "detects API Gateway v2 HTTP trigger type"() { - given: - def event = [ - requestContext: [ - http: [ - method: "POST", - path: "/api" - ], - domainName: "api.example.com" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP - } - - def "detects Lambda Function URL trigger type"() { - given: - def event = [ - requestContext: [ - http: [ - method: "GET", - path: "/" - ], - domainName: "xyz123.lambda-url.us-east-1.on.aws" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL - } - - def "detects ALB trigger type without multi-value headers"() { - given: - def event = [ - httpMethod: "GET", - path: "/", - requestContext: [ - elb: [ - targetGroupArn: "arn:aws:..." - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB - } - - def "detects ALB trigger type with multi-value headers"() { - given: - def event = [ - httpMethod: "GET", - path: "/", - multiValueHeaders: [ - accept: ["text/html", "application/json"] - ], - requestContext: [ - elb: [ - targetGroupArn: "arn:aws:..." - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE - } - - def "detects WebSocket trigger type with routeKey"() { - given: - def event = [ - requestContext: [ - connectionId: "conn-123", - routeKey: "\$connect" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET - } - - def "detects WebSocket trigger type with eventType"() { - given: - def event = [ - requestContext: [ - connectionId: "conn-456", - eventType: "CONNECT" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET - } - - def "detects unknown trigger type for unrecognized events"() { - given: - def event = [ - someUnknownField: "value" - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN - } - - def "detects unknown trigger type for empty requestContext"() { - given: - def event = [ - requestContext: [:] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN - } - - def "detects Lambda URL when http present but no domainName"() { - given: - def event = [ - requestContext: [ - http: [ - method: "GET", - path: "/ambiguous" - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL - } - - // ============================================================================ - // Data Extraction Tests with Mocked Callbacks - // ============================================================================ - - def "extracts API Gateway v1 REST data correctly"() { - given: - def eventJson = ''' - { - "path": "/api/users/123", - "httpMethod": "POST", - "headers": { - "Content-Type": "application/json", - "Authorization": "Bearer token123" - }, - "pathParameters": { - "userId": "123" - }, - "body": "{\\"name\\": \\"John\\"}", - "requestContext": { - "httpMethod": "POST", - "requestId": "req-123", - "identity": { - "sourceIp": "192.168.1.100" - } - } - } - ''' - def event = createInputStream(eventJson) - - // Track callback invocations - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - def capturedSourcePort = null - def capturedPathParams = null - def capturedBody = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - }, - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - result instanceof TagContext - - capturedMethod == "POST" - capturedPath == "/api/users/123" - capturedHeaders["Content-Type"] == "application/json" - capturedHeaders["Authorization"] == "Bearer token123" - capturedSourceIp == "192.168.1.100" - capturedSourcePort == 0 - capturedPathParams == ["userId": "123"] - capturedBody instanceof Map - capturedBody.name == "John" - } - - def "extracts API Gateway v2 HTTP data correctly"() { - given: - def eventJson = ''' - { - "version": "2.0", - "headers": { - "content-type": "application/json", - "x-custom-header": "custom-value" - }, - "cookies": ["session=abc123", "user=john"], - "pathParameters": { - "id": "456" - }, - "body": "test body", - "requestContext": { - "http": { - "method": "PUT", - "path": "/api/items/456", - "sourceIp": "10.0.0.50", - "sourcePort": 54321 - }, - "domainName": "api.example.com" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - def capturedSourcePort = null - def capturedPathParams = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "PUT" - capturedPath == "/api/items/456" - capturedHeaders["content-type"] == "application/json" - capturedHeaders["x-custom-header"] == "custom-value" - capturedHeaders["cookie"] == "session=abc123; user=john" - capturedSourceIp == "10.0.0.50" - capturedSourcePort == 54321 - capturedPathParams == ["id": "456"] - } - - def "extracts Lambda Function URL data correctly"() { - given: - def eventJson = ''' - { - "version": "2.0", - "headers": { - "host": "xyz.lambda-url.us-east-1.on.aws" - }, - "requestContext": { - "http": { - "method": "GET", - "path": "/function/path", - "sourceIp": "1.2.3.4" - }, - "domainName": "xyz.lambda-url.us-east-1.on.aws" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "GET" - capturedPath == "/function/path" - } - - def "extracts ALB data correctly"() { - given: - def eventJson = ''' - { - "path": "/alb/test", - "httpMethod": "DELETE", - "headers": { - "x-forwarded-for": "203.0.113.42", - "user-agent": "curl/7.64.1" - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "DELETE" - capturedPath == "/alb/test" - capturedSourceIp == "203.0.113.42" - } - - def "extracts ALB multi-value headers correctly"() { - given: - def eventJson = ''' - { - "path": "/test", - "httpMethod": "GET", - "multiValueHeaders": { - "accept": ["text/html", "application/json"], - "x-custom": ["value1", "value2"] - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:..." - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["accept"] == "text/html, application/json" - capturedHeaders["x-custom"] == "value1, value2" - } - - def "handles multi-value headers with empty list"() { - given: - def eventJson = ''' - { - "path": "/test", - "httpMethod": "GET", - "multiValueHeaders": { - "accept": [], - "x-custom": ["value1"] - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:..." - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["accept"] == "" // Empty list should result in empty string - capturedHeaders["x-custom"] == "value1" - } - - def "extracts WebSocket data correctly"() { - given: - def eventJson = ''' - { - "requestContext": { - "routeKey": "$connect", - "connectionId": "conn-abc123", - "identity": { - "sourceIp": "192.168.0.100" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "WEBSOCKET" - capturedPath == "\$connect" - capturedSourceIp == "192.168.0.100" - } - - def "handles base64 encoded body correctly"() { - given: - def originalBody = "This is test data" - def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()) - def eventJson = """ - { - "body": "${base64Body}", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - """ - def event = createInputStream(eventJson) - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == originalBody - } - - def "handles null body correctly"() { - given: - def event = createInputStream('{"body": null, "requestContext": {"httpMethod": "GET"}}') - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "NOT_CALLED" // Callback should not be invoked for null body - } - - def "handles empty body correctly"() { - given: - def event = createInputStream('{"body": "", "requestContext": {"httpMethod": "POST"}}') - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "" // Empty body is passed as empty string to WAF - } - - def "handles path with query string correctly"() { - given: - def eventJson = ''' - { - "path": "/api/users?id=123&filter=active", - "requestContext": { - "httpMethod": "GET" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedPath = null - def capturedQuery = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedPath = uri.path() - capturedQuery = uri.query() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedPath == "/api/users" - capturedQuery == "id=123&filter=active" - } - - def "extracts scheme and port from X-Forwarded headers"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": { - "x-forwarded-proto": "http", - "x-forwarded-port": "8080" - }, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "http" - capturedPort == 8080 - } - - def "falls back to https/443 when X-Forwarded headers are absent"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": {}, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "https" - capturedPort == 443 - } - - def "handles invalid X-Forwarded-Port gracefully"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": { - "x-forwarded-proto": "https", - "x-forwarded-port": "not-a-number" - }, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "https" - capturedPort == 443 - } - - def "handles invalid base64 body gracefully"() { - given: - def eventJson = ''' - { - "body": "not-valid-base64", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "NOT_CALLED" // Should not call body callback when decode fails - } - - def "handles base64 decoded empty string body"() { - given: - def base64Empty = Base64.getEncoder().encodeToString("".getBytes()) - def eventJson = """ - { - "body": "${base64Empty}", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - """ - def event = createInputStream(eventJson) - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "" // Should pass empty string after decoding - } - - def "handles body with special characters"() { - given: - def eventJson = ''' - { - "body": "{\\"text\\": \\"Hello δΈ–η•Œ 🌍\\"}", - "requestContext": { - "httpMethod": "POST" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody instanceof Map - capturedBody.text == "Hello δΈ–η•Œ 🌍" - } - - // ============================================================================ - // Generic Data Extraction Tests - // ============================================================================ - - def "extracts data from unknown trigger type using generic extraction"() { - given: - def eventJson = ''' - { - "path": "/generic/path", - "httpMethod": "PATCH", - "headers": { - "x-custom-header": "generic-value" - }, - "unknownField": "should be ignored", - "requestContext": { - "identity": { - "sourceIp": "203.0.113.1" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "PATCH" - capturedPath == "/generic/path" - capturedHeaders["x-custom-header"] == "generic-value" - capturedSourceIp == "203.0.113.1" - } - - def "extracts data from unknown trigger with http in requestContext"() { - given: - def eventJson = ''' - { - "requestContext": { - "http": { - "method": "OPTIONS", - "path": "/options/path", - "sourceIp": "198.51.100.50" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "OPTIONS" - capturedPath == "/options/path" - capturedSourceIp == "198.51.100.50" - } - - def "handles cookies merging with existing cookie header"() { - given: - def eventJson = ''' - { - "headers": { - "cookie": "existing=value" - }, - "cookies": ["new=cookie1", "another=cookie2"], - "requestContext": { - "http": { - "method": "GET", - "path": "/" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["cookie"] == "existing=value; new=cookie1; another=cookie2" - } - - def "handles empty cookies array correctly"() { - given: - def eventJson = ''' - { - "headers": { - "content-type": "application/json" - }, - "cookies": [], - "requestContext": { - "http": { - "method": "GET", - "path": "/" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - !capturedHeaders.containsKey("cookie") // Empty array should not add cookie header - } - - // ============================================================================ - // processRequestEnd Tests - // ============================================================================ - - def "processRequestEnd does nothing when span is null"() { - when: - LambdaAppSecHandler.processRequestEnd(null) - - then: - noExceptionThrown() - } - - def "processRequestEnd does nothing when AppSec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - def span = Mock(AgentSpan) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - 0 * span._ - } - - def "processRequestEnd does nothing when span has no RequestContext"() { - given: - def span = Mock(AgentSpan) { - getRequestContext() >> null - } - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - noExceptionThrown() - } - - def "processRequestEnd invokes requestEnded callback with RequestContext"() { - given: - def mockAppSecContext = new Object() - def mockRequestContext = Mock(RequestContext) { - getData(RequestContextSlot.APPSEC) >> mockAppSecContext - } - def span = Mock(AgentSpan) { - getRequestContext() >> mockRequestContext - } - - def callbackInvoked = false - def capturedContext = null - def capturedSpan = null - - def mockRequestEndedCallback = Mock(BiFunction) { - apply(_ as RequestContext, _ as AgentSpan) >> { - RequestContext ctx, AgentSpan s -> - callbackInvoked = true - capturedContext = ctx - capturedSpan = s - return new Flow.ResultFlow<>(null) - } - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestEnded()) >> mockRequestEndedCallback - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - callbackInvoked - capturedContext == mockRequestContext - capturedSpan == span - } - - def "processRequestEnd handles null requestEnded callback gracefully"() { - given: - def mockRequestContext = Mock(RequestContext) - def span = Mock(AgentSpan) { - getRequestContext() >> mockRequestContext - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestEnded()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - noExceptionThrown() // Should log warning but not throw - } - - // ============================================================================ - // mergeContexts Tests - // ============================================================================ - - def "mergeContexts returns null when both contexts are null"() { - when: - def result = LambdaAppSecHandler.mergeContexts(null, null) - - then: - result == null - } - - def "mergeContexts returns extensionContext when appSecContext is null"() { - given: - def extensionContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, null) - - then: - result == extensionContext - } - - def "mergeContexts returns appSecContext when extensionContext is null"() { - given: - def appSecContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(null, appSecContext) - - then: - result == appSecContext - } - - def "mergeContexts merges AppSec data into TagContext"() { - given: - def appSecData = new Object() - - // Create real TagContext instances since methods are final - def appSecContext = new TagContext() - appSecContext.withRequestContextDataAppSec(appSecData) - - def extensionContext = new TagContext() - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - result.getRequestContextDataAppSec() == appSecData - } - - def "mergeContexts returns extensionContext when appSecContext is not TagContext"() { - given: - def extensionContext = Mock(TagContext) - def appSecContext = Mock(AgentSpanContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - } - - def "mergeContexts returns extensionContext when it is not TagContext"() { - given: - def extensionContext = Mock(AgentSpanContext) - def appSecContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - } - - // ============================================================================ - // Error Handling and Null Callback Tests - // ============================================================================ - - def "processRequestStart handles null requestStarted callback gracefully"() { - given: - def eventJson = '{"requestContext": {"httpMethod": "GET"}}' - def event = createInputStream(eventJson) - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null // Should return null when requestStarted callback is missing - } - - def "processRequestStart handles null methodUri callback gracefully"() { - given: - def eventJson = ''' - { - "path": "/test", - "requestContext": { - "httpMethod": "GET" - } - } - ''' - def event = createInputStream(eventJson) - - def mockAppSecContext = new Object() - - def mockRequestStartedCallback = Mock(Supplier) { - get() >> new Flow.ResultFlow<>(mockAppSecContext) - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback - getCallback(EVENTS.requestMethodUriRaw()) >> null // Null callback - getCallback(EVENTS.requestHeader()) >> null - getCallback(EVENTS.requestClientSocketAddress()) >> null - getCallback(EVENTS.requestHeaderDone()) >> Mock(Function) { - apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) - } - getCallback(EVENTS.requestPathParams()) >> null - getCallback(EVENTS.requestBodyProcessed()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null // Should continue processing even if methodUri callback is null - result instanceof TagContext - } - - def "processRequestStart handles exception during JSON parsing"() { - given: - def invalidJson = '{this is not valid JSON at all' - def event = createInputStream(invalidJson) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null // Should return null on parse error - } - - def "processRequestStart handles exception during stream reading"() { - given: - def mockStream = Mock(ByteArrayInputStream) { - available() >> { throw new IOException("Stream error") } - } - - when: - def result = LambdaAppSecHandler.processRequestStart(mockStream) - - then: - result == null // Should return null on IO error - } - - // ============================================================================ - // Helper Methods - // ============================================================================ - - private ByteArrayInputStream createInputStream(String json) { - return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) - } - - /** - * Set up mock callbacks to capture invocations and verify data extraction. - * This mocks the AgentTracer and callback provider to intercept gateway calls. - */ - private void setupMockCallbacks(Map callbacks) { - def mockAppSecContext = new Object() - - def mockRequestStartedCallback = Mock(Supplier) { - get() >> new Flow.ResultFlow<>(mockAppSecContext) - } - - def mockMethodUriCallback = callbacks.onMethodUri ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { - RequestContext ctx, String method, URIDataAdapter uri -> - callbacks.onMethodUri(method, uri) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { - accept(_ as RequestContext, _ as String, _ as String) >> { - RequestContext ctx, String name, String value -> - callbacks.onHeader(name, value) - } - } : null - - def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as Integer) >> { - RequestContext ctx, String ip, Integer port -> - callbacks.onSocketAddress(ip, port) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockHeaderDoneCallback = Mock(Function) { - apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) - } - - def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { - RequestContext ctx, Map params -> - callbacks.onPathParams(params) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Object) >> { - RequestContext ctx, Object body -> - callbacks.onBody(body) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback - getCallback(EVENTS.requestMethodUriRaw()) >> mockMethodUriCallback - getCallback(EVENTS.requestHeader()) >> mockHeaderCallback - getCallback(EVENTS.requestClientSocketAddress()) >> mockSocketAddressCallback - getCallback(EVENTS.requestHeaderDone()) >> mockHeaderDoneCallback - getCallback(EVENTS.requestPathParams()) >> mockPathParamsCallback - getCallback(EVENTS.requestBodyProcessed()) >> mockBodyCallback - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - // Install the mock tracer - AgentTracer.forceRegister(mockTracer) - } - - // ============================================================================ - // processResponseData Tests - // ============================================================================ - - def "processResponseData does nothing when AppSec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - def span = Mock(AgentSpan) - def result = createOutputStream('{"statusCode": 200, "body": "ok"}') - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - 0 * span._ - } - - def "processResponseData does nothing for null span"() { - given: - def result = createOutputStream('{"statusCode": 200}') - - when: - LambdaAppSecHandler.processResponseData(null, result) - - then: - noExceptionThrown() - } - - def "processResponseData does nothing for non-ByteArrayOutputStream result"() { - given: - def span = Mock(AgentSpan) - - when: - LambdaAppSecHandler.processResponseData(span, "string result") - - then: - 0 * span._ - } - - def "processResponseData does nothing for null result"() { - given: - def span = Mock(AgentSpan) - - when: - LambdaAppSecHandler.processResponseData(span, null) - - then: - 0 * span._ - } - - def "processResponseData does nothing when span has no RequestContext"() { - given: - def span = Mock(AgentSpan) { - getRequestContext() >> null - } - def result = createOutputStream('{"statusCode": 200}') - - setupMockResponseCallbacks([:]) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - noExceptionThrown() - } - - def "processResponseData does nothing for oversized response"() { - given: - def maxSize = Config.get().getAppSecBodyParsingSizeLimit() - def largeBody = "x" * (maxSize + 1) - def result = createOutputStream(largeBody) - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedStatus == null - } - - def "processResponseData does nothing for empty ByteArrayOutputStream"() { - given: - def result = new ByteArrayOutputStream() // 0 bytes - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedStatus == null - } - - // --- Status code extraction --- - - def "processResponseData extracts statusCode correctly"() { - given: - def result = createOutputStream('{"statusCode": 200, "body": "ok"}') - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedStatus == 200 - } - - def "processResponseData extracts statusCode as integer from double"() { - given: - def result = createOutputStream('{"statusCode": 404.0, "body": "not found"}') - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedStatus == 404 - } - - def "processResponseData handles missing statusCode"() { - given: - def result = createOutputStream('{"body": "ok"}') - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedStatus == null // statusCode is 0, so responseStarted is not called - } - - def "processResponseData handles non-numeric statusCode"() { - given: - def result = createOutputStream('{"statusCode": "bad", "body": "ok"}') - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - noExceptionThrown() - capturedStatus == null // "bad" is not a Number, statusCode stays 0 - } - - // --- Header extraction --- - - def "processResponseData forwards all response headers"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "application/json", "x-custom": "val", "content-length": "42", "set-cookie": "a=1"}}' - def result = createOutputStream(json) - def capturedHeaders = [:] - - def span = setupMockResponseCallbacks( - onResponseHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedHeaders.size() == 4 - capturedHeaders["content-type"] == "application/json" - capturedHeaders["x-custom"] == "val" - capturedHeaders["content-length"] == "42" - capturedHeaders["set-cookie"] == "a=1" - } - - def "processResponseData preserves original header casing"() { - given: - def json = '{"statusCode": 200, "headers": {"Content-Type": "text/html", "CONTENT-LENGTH": "10"}}' - def result = createOutputStream(json) - def capturedHeaders = [:] - - def span = setupMockResponseCallbacks( - onResponseHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedHeaders["Content-Type"] == "text/html" - capturedHeaders["CONTENT-LENGTH"] == "10" - } - - def "processResponseData merges multiValueHeaders with single-value headers"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "text/html"}, "multiValueHeaders": {"content-encoding": ["gzip", "br"]}}' - def result = createOutputStream(json) - def capturedHeaders = [:] - - def span = setupMockResponseCallbacks( - onResponseHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedHeaders["content-type"] == "text/html" - capturedHeaders["content-encoding"] == "gzip, br" - } - - def "processResponseData handles empty headers"() { - given: - def result = createOutputStream('{"statusCode": 200}') - def capturedHeaders = [:] - def headerDoneCalled = false - - def span = setupMockResponseCallbacks( - onResponseHeader: { name, value -> capturedHeaders[name] = value }, - onResponseHeaderDone: { - headerDoneCalled = true - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedHeaders.isEmpty() - headerDoneCalled - } - - // --- Body extraction --- - - def "processResponseData parses JSON body"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "application/json"}, "body": "{\\"key\\": \\"value\\"}"}' - def result = createOutputStream(json) - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody instanceof Map - capturedBody["key"] == "value" - } - - def "processResponseData handles non-JSON body as raw string"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "text/plain"}, "body": "plain text"}' - def result = createOutputStream(json) - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody == "plain text" - } - - def "processResponseData handles base64 encoded body"() { - given: - def originalBody = '{"decoded": "content"}' - def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes(StandardCharsets.UTF_8)) - def json = """{"statusCode": 200, "body": "${base64Body}", "isBase64Encoded": true}""" - def result = createOutputStream(json) - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody instanceof Map - capturedBody["decoded"] == "content" - } - - def "processResponseData handles null body"() { - given: - def result = createOutputStream('{"statusCode": 200, "body": null}') - def capturedBody = "NOT_CALLED" - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody == "NOT_CALLED" - } - - def "processResponseData handles missing body field"() { - given: - def result = createOutputStream('{"statusCode": 200}') - def capturedBody = "NOT_CALLED" - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody == "NOT_CALLED" - } - - def "processResponseData attempts JSON parse when no content-type"() { - given: - def result = createOutputStream('{"statusCode": 200, "body": "{\\"a\\": 1}"}') - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody instanceof Map - capturedBody["a"] == 1.0d // Moshi parses numbers as Double - } - - def "processResponseData falls back to raw string when JSON parse fails"() { - given: - def result = createOutputStream('{"statusCode": 200, "body": "not json {"}') - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody == "not json {" - } - - // --- Event ordering --- - - def "processResponseData fires events in correct order"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "application/json"}, "body": "{\\"k\\": \\"v\\"}"}' - def result = createOutputStream(json) - def order = [] - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> order << "responseStarted" }, - onResponseHeader: { name, value -> order << "responseHeader" }, - onResponseHeaderDone: { order << "responseHeaderDone" }, - onResponseBody: { body -> - order << "responseBody" - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - order[0] == "responseStarted" - order.findAll { it == "responseHeader" }.size() >= 1 - def headerDoneIdx = order.indexOf("responseHeaderDone") - def lastHeaderIdx = order.lastIndexOf("responseHeader") - headerDoneIdx > lastHeaderIdx - order.last() == "responseBody" - } - - def "processResponseData handles invalid base64 response body gracefully"() { - given: - def json = '{"statusCode": 200, "body": "not-valid-base64!!!", "isBase64Encoded": true}' - def result = createOutputStream(json) - def capturedBody = "NOT_CALLED" - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - noExceptionThrown() - capturedBody == "NOT_CALLED" - } - - def "processResponseData parses body as JSON for javascript content-type"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "application/javascript"}, "body": "{\\"key\\": \\"val\\"}"}' - def result = createOutputStream(json) - def capturedBody = null - - def span = setupMockResponseCallbacks( - onResponseBody: { body -> - capturedBody = body - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedBody instanceof Map - capturedBody["key"] == "val" - } - - def "processResponseData multiValueHeaders override single-value headers"() { - given: - def json = '{"statusCode": 200, "headers": {"content-type": "text/html"}, "multiValueHeaders": {"content-type": ["application/json", "charset=utf-8"]}}' - def result = createOutputStream(json) - def capturedHeaders = [:] - - def span = setupMockResponseCallbacks( - onResponseHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - capturedHeaders["content-type"] == "application/json, charset=utf-8" - } - - // --- Error handling --- - - def "processResponseData handles malformed JSON response"() { - given: - def result = createOutputStream('{not valid json') - def capturedStatus = null - - def span = setupMockResponseCallbacks( - onResponseStarted: { status -> - capturedStatus = status - } - ) - - when: - LambdaAppSecHandler.processResponseData(span, result) - - then: - noExceptionThrown() - capturedStatus == null - } - - def "processResponseData handles empty string response"() { - given: - // Empty string inside a non-empty BAOS (the string "") - // This would fail JSON parsing - def result = createOutputStream('') - - when: - LambdaAppSecHandler.processResponseData(Mock(AgentSpan), result) - - then: - noExceptionThrown() - } - - // ============================================================================ - // extractResponseData Unit Tests - // ============================================================================ - - def "extractResponseData returns null for malformed JSON"() { - when: - def result = LambdaAppSecHandler.extractResponseData('{bad json') - - then: - result == null - } - - def "extractResponseData returns null for null JSON parse result"() { - when: - def result = LambdaAppSecHandler.extractResponseData('null') - - then: - result == null - } - - // ============================================================================ - // Helper Methods - // ============================================================================ - - private ByteArrayOutputStream createOutputStream(String json) { - def baos = new ByteArrayOutputStream() - baos.write(json.getBytes(StandardCharsets.UTF_8)) - return baos - } - - /** - * Set up mock response callbacks and return a mock span with a valid RequestContext. - * processResponseData uses span.getRequestContext() to get the RequestContext, - * unlike processRequestStart which uses TemporaryRequestContext. - */ - private AgentSpan setupMockResponseCallbacks(Map callbacks) { - def mockAppSecContext = new Object() - def mockRequestContext = Mock(RequestContext) { - getData(RequestContextSlot.APPSEC) >> mockAppSecContext - } - def mockSpan = Mock(AgentSpan) { - getRequestContext() >> mockRequestContext - } - - def mockResponseStartedCb = callbacks.onResponseStarted ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Integer) >> { - RequestContext ctx, Integer status -> - callbacks.onResponseStarted(status) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockResponseHeaderCb = callbacks.onResponseHeader ? Mock(TriConsumer) { - accept(_ as RequestContext, _ as String, _ as String) >> { - RequestContext ctx, String name, String value -> - callbacks.onResponseHeader(name, value) - } - } : null - - def mockResponseHeaderDoneCb = callbacks.onResponseHeaderDone ? Mock(Function) { - apply(_ as RequestContext) >> { - callbacks.onResponseHeaderDone() - return new Flow.ResultFlow<>(null) - } - } : Mock(Function) { - apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) - } - - def mockResponseBodyCb = callbacks.onResponseBody ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Object) >> { - RequestContext ctx, Object body -> - callbacks.onResponseBody(body) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.responseStarted()) >> mockResponseStartedCb - getCallback(EVENTS.responseHeader()) >> mockResponseHeaderCb - getCallback(EVENTS.responseHeaderDone()) >> mockResponseHeaderDoneCb - getCallback(EVENTS.responseBody()) >> mockResponseBodyCb - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - return mockSpan - } - - def cleanup() { - // Reset tracer after each test - AgentTracer.forceRegister(null) - } -} diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java index eed9ac52e9b..4b39677f046 100644 --- a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java @@ -6,10 +6,14 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import datadog.trace.api.Config; @@ -28,114 +32,101 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; import datadog.trace.core.DDCoreJavaSpecification; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Base64; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -@SuppressWarnings("unchecked") -public class LambdaAppSecHandlerTest extends DDCoreJavaSpecification { +class LambdaAppSecHandlerTest extends DDCoreJavaSpecification { - private static boolean originalAppSecActive; - private static AgentTracer.TracerAPI originalTracer; + static boolean originalAppSecActive; + static AgentTracer.TracerAPI originalTracer; @BeforeAll - static void setupSpec() { + static void saveState() { originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE; originalTracer = AgentTracer.get(); } + @AfterAll + static void restoreAppSecState() { + ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive; + } + @BeforeEach - void setup() { + void enableAppSec() { ActiveSubsystems.APPSEC_ACTIVE = true; } @AfterEach - void cleanup() { - ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive; + void resetTracer() { AgentTracer.forceRegister(originalTracer); + LambdaAppSecHandler.setCurrentTriggerType(null); } // ============================================================================ - // processRequestStart basic tests + // processRequestStart β€” guard tests // ============================================================================ @Test void processRequestStartReturnsNullWhenAppSecIsDisabled() { ActiveSubsystems.APPSEC_ACTIVE = false; ByteArrayInputStream event = createInputStream("{\"test\": \"data\"}"); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test void processRequestStartReturnsNullForNonByteArrayInputStream() { - AgentSpanContext result = LambdaAppSecHandler.processRequestStart("not a stream"); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart("not a stream")); } @Test void processRequestStartReturnsNullForNullEvent() { - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(null); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(null)); } @Test void processRequestStartReturnsNullForOversizedEvent() { int maxSize = Config.get().getAppSecBodyParsingSizeLimit(); - String largeBody = repeatChar('x', maxSize + 1); - ByteArrayInputStream event = createInputStream(largeBody); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + char[] chars = new char[maxSize + 1]; + java.util.Arrays.fill(chars, 'x'); + ByteArrayInputStream event = createInputStream(new String(chars)); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test void processRequestStartReturnsNullForZeroSizeEvent() { ByteArrayInputStream event = createInputStream(""); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test - void processRequestStartReturnsNullForMalformedJSON() { + void processRequestStartReturnsNullForMalformedJson() { ByteArrayInputStream event = createInputStream("{invalid json"); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test - void streamCanBeReadMultipleTimesAfterProcessing() throws Exception { + void streamCanBeReadMultipleTimesAfterProcessing() throws IOException { String jsonData = "{\"test\": \"data\", \"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(jsonData); - LambdaAppSecHandler.processRequestStart(event); event.reset(); byte[] bytes = new byte[event.available()]; event.read(bytes); String content = new String(bytes, StandardCharsets.UTF_8); - assertEquals(jsonData, content); } @@ -145,133 +136,130 @@ void streamCanBeReadMultipleTimesAfterProcessing() throws Exception { @Test void detectsApiGatewayV1RestTriggerType() { - Map event = - mapOf("requestContext", mapOf("httpMethod", "GET", "requestId", "abc123")); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST, triggerType); + Map event = new HashMap<>(); + Map requestContext = new HashMap<>(); + requestContext.put("httpMethod", "GET"); + requestContext.put("requestId", "abc123"); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsApiGatewayV2HttpTriggerType() { - Map event = - mapOf( - "requestContext", - mapOf( - "http", mapOf("method", "POST", "path", "/api"), "domainName", "api.example.com")); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP, triggerType); + Map http = new HashMap<>(); + http.put("method", "POST"); + http.put("path", "/api"); + Map requestContext = new HashMap<>(); + requestContext.put("http", http); + requestContext.put("domainName", "api.example.com"); + Map event = new HashMap<>(); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsLambdaFunctionUrlTriggerType() { - Map event = - mapOf( - "requestContext", - mapOf( - "http", - mapOf("method", "GET", "path", "/"), - "domainName", - "xyz123.lambda-url.us-east-1.on.aws")); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); + Map http = new HashMap<>(); + http.put("method", "GET"); + http.put("path", "/"); + Map requestContext = new HashMap<>(); + requestContext.put("http", http); + requestContext.put("domainName", "xyz123.lambda-url.us-east-1.on.aws"); + Map event = new HashMap<>(); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsAlbTriggerTypeWithoutMultiValueHeaders() { - Map event = - mapOf( - "httpMethod", - "GET", - "path", - "/", - "requestContext", - mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB, triggerType); + Map elb = new HashMap<>(); + elb.put("targetGroupArn", "arn:aws:..."); + Map requestContext = new HashMap<>(); + requestContext.put("elb", elb); + Map event = new HashMap<>(); + event.put("httpMethod", "GET"); + event.put("path", "/"); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.ALB, LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsAlbTriggerTypeWithMultiValueHeaders() { - Map event = - mapOf( - "httpMethod", - "GET", - "path", - "/", - "multiValueHeaders", - mapOf("accept", Arrays.asList("text/html", "application/json")), - "requestContext", - mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE, triggerType); + Map elb = new HashMap<>(); + elb.put("targetGroupArn", "arn:aws:..."); + Map requestContext = new HashMap<>(); + requestContext.put("elb", elb); + Map event = new HashMap<>(); + event.put("httpMethod", "GET"); + event.put("path", "/"); + event.put("multiValueHeaders", new HashMap<>()); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsWebSocketTriggerTypeWithRouteKey() { - Map event = - mapOf("requestContext", mapOf("connectionId", "conn-123", "routeKey", "$connect")); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); + Map requestContext = new HashMap<>(); + requestContext.put("connectionId", "conn-123"); + requestContext.put("routeKey", "$connect"); + Map event = new HashMap<>(); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsWebSocketTriggerTypeWithEventType() { - Map event = - mapOf("requestContext", mapOf("connectionId", "conn-456", "eventType", "CONNECT")); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); + Map requestContext = new HashMap<>(); + requestContext.put("connectionId", "conn-456"); + requestContext.put("eventType", "CONNECT"); + Map event = new HashMap<>(); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsUnknownTriggerTypeForUnrecognizedEvents() { - Map event = mapOf("someUnknownField", "value"); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); + Map event = new HashMap<>(); + event.put("someUnknownField", "value"); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsUnknownTriggerTypeForEmptyRequestContext() { - Map event = mapOf("requestContext", mapOf()); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); + Map event = new HashMap<>(); + event.put("requestContext", new HashMap<>()); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, + LambdaAppSecHandler.detectTriggerType(event)); } @Test void detectsLambdaUrlWhenHttpPresentButNoDomainName() { - Map event = - mapOf("requestContext", mapOf("http", mapOf("method", "GET", "path", "/ambiguous"))); - - LambdaAppSecHandler.LambdaTriggerType triggerType = - LambdaAppSecHandler.detectTriggerType(event); - - assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); + Map http = new HashMap<>(); + http.put("method", "GET"); + http.put("path", "/ambiguous"); + Map requestContext = new HashMap<>(); + requestContext.put("http", http); + Map event = new HashMap<>(); + event.put("requestContext", requestContext); + assertEquals( + LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, + LambdaAppSecHandler.detectTriggerType(event)); } // ============================================================================ @@ -279,26 +267,20 @@ void detectsLambdaUrlWhenHttpPresentButNoDomainName() { // ============================================================================ @Test + @SuppressWarnings("unchecked") void extractsApiGatewayV1RestDataCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/api/users/123\",\n" - + " \"httpMethod\": \"POST\",\n" - + " \"headers\": {\n" - + " \"Content-Type\": \"application/json\",\n" - + " \"Authorization\": \"Bearer token123\"\n" - + " },\n" - + " \"pathParameters\": {\n" - + " \"userId\": \"123\"\n" - + " },\n" - + " \"body\": \"{\\\"name\\\": \\\"John\\\"}\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\",\n" - + " \"requestId\": \"req-123\",\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"192.168.1.100\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/api/users/123\"," + + "\"httpMethod\": \"POST\"," + + "\"headers\": {\"Content-Type\": \"application/json\", \"Authorization\": \"Bearer token123\"}," + + "\"pathParameters\": {\"userId\": \"123\"}," + + "\"body\": \"{\\\"name\\\": \\\"John\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"POST\"," + + " \"requestId\": \"req-123\"," + + " \"identity\": {\"sourceIp\": \"192.168.1.100\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -306,38 +288,35 @@ void extractsApiGatewayV1RestDataCorrectly() { String[] capturedPath = {null}; Map capturedHeaders = new HashMap<>(); String[] capturedSourceIp = {null}; - Integer[] capturedSourcePort = {null}; - Map[] capturedPathParams = new Map[] {null}; + int[] capturedSourcePort = {-1}; + Map[] capturedPathParams = {null}; Object[] capturedBody = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) - .onSocketAddress( - (ip, port) -> { - capturedSourceIp[0] = ip; - capturedSourcePort[0] = port; - }) - .onPathParams(params -> capturedPathParams[0] = params) - .onBody(body -> capturedBody[0] = body)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + capturedHeaders::put, + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }, + params -> capturedPathParams[0] = params, + body -> capturedBody[0] = body); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); assertInstanceOf(TagContext.class, result); - assertEquals("POST", capturedMethod[0]); assertEquals("/api/users/123", capturedPath[0]); assertEquals("application/json", capturedHeaders.get("Content-Type")); assertEquals("Bearer token123", capturedHeaders.get("Authorization")); assertEquals("192.168.1.100", capturedSourceIp[0]); - assertEquals(Integer.valueOf(0), capturedSourcePort[0]); - assertEquals("123", ((Map) capturedPathParams[0]).get("userId")); + assertEquals(0, capturedSourcePort[0]); + assertNotNull(capturedPathParams[0]); + assertEquals("123", capturedPathParams[0].get("userId")); assertInstanceOf(Map.class, capturedBody[0]); assertEquals("John", ((Map) capturedBody[0]).get("name")); } @@ -345,26 +324,16 @@ void extractsApiGatewayV1RestDataCorrectly() { @Test void extractsApiGatewayV2HttpDataCorrectly() { String eventJson = - "{\n" - + " \"version\": \"2.0\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-custom-header\": \"custom-value\"\n" - + " },\n" - + " \"cookies\": [\"session=abc123\", \"user=john\"],\n" - + " \"pathParameters\": {\n" - + " \"id\": \"456\"\n" - + " },\n" - + " \"body\": \"test body\",\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"PUT\",\n" - + " \"path\": \"/api/items/456\",\n" - + " \"sourceIp\": \"10.0.0.50\",\n" - + " \"sourcePort\": 54321\n" - + " },\n" - + " \"domainName\": \"api.example.com\"\n" - + " }\n" + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"content-type\": \"application/json\", \"x-custom-header\": \"custom-value\"}," + + "\"cookies\": [\"session=abc123\", \"user=john\"]," + + "\"pathParameters\": {\"id\": \"456\"}," + + "\"body\": \"test body\"," + + "\"requestContext\": {" + + " \"http\": {\"method\": \"PUT\", \"path\": \"/api/items/456\", \"sourceIp\": \"10.0.0.50\", \"sourcePort\": 54321}," + + " \"domainName\": \"api.example.com\"" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -372,23 +341,21 @@ void extractsApiGatewayV2HttpDataCorrectly() { String[] capturedPath = {null}; Map capturedHeaders = new HashMap<>(); String[] capturedSourceIp = {null}; - Integer[] capturedSourcePort = {null}; - Map[] capturedPathParams = new Map[] {null}; + int[] capturedSourcePort = {-1}; + Map[] capturedPathParams = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) - .onSocketAddress( - (ip, port) -> { - capturedSourceIp[0] = ip; - capturedSourcePort[0] = port; - }) - .onPathParams(params -> capturedPathParams[0] = params)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + capturedHeaders::put, + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }, + params -> capturedPathParams[0] = params, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -399,26 +366,21 @@ void extractsApiGatewayV2HttpDataCorrectly() { assertEquals("custom-value", capturedHeaders.get("x-custom-header")); assertEquals("session=abc123; user=john", capturedHeaders.get("cookie")); assertEquals("10.0.0.50", capturedSourceIp[0]); - assertEquals(Integer.valueOf(54321), capturedSourcePort[0]); - assertEquals("456", ((Map) capturedPathParams[0]).get("id")); + assertEquals(54321, capturedSourcePort[0]); + assertNotNull(capturedPathParams[0]); + assertEquals("456", capturedPathParams[0].get("id")); } @Test void extractsLambdaFunctionUrlDataCorrectly() { String eventJson = - "{\n" - + " \"version\": \"2.0\",\n" - + " \"headers\": {\n" - + " \"host\": \"xyz.lambda-url.us-east-1.on.aws\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/function/path\",\n" - + " \"sourceIp\": \"1.2.3.4\"\n" - + " },\n" - + " \"domainName\": \"xyz.lambda-url.us-east-1.on.aws\"\n" - + " }\n" + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"host\": \"xyz.lambda-url.us-east-1.on.aws\"}," + + "\"requestContext\": {" + + " \"http\": {\"method\": \"GET\", \"path\": \"/function/path\", \"sourceIp\": \"1.2.3.4\"}," + + " \"domainName\": \"xyz.lambda-url.us-east-1.on.aws\"" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -426,12 +388,14 @@ void extractsLambdaFunctionUrlDataCorrectly() { String[] capturedPath = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - })); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + null, + null, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -443,18 +407,13 @@ void extractsLambdaFunctionUrlDataCorrectly() { @Test void extractsAlbDataCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/alb/test\",\n" - + " \"httpMethod\": \"DELETE\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-for\": \"203.0.113.42\",\n" - + " \"user-agent\": \"curl/7.64.1\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/alb/test\"," + + "\"httpMethod\": \"DELETE\"," + + "\"headers\": {\"x-forwarded-for\": \"203.0.113.42\", \"user-agent\": \"curl/7.64.1\"}," + + "\"requestContext\": {" + + " \"elb\": {\"targetGroupArn\": \"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/tg/50dc6c495c0c9188\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -463,13 +422,14 @@ void extractsAlbDataCorrectly() { String[] capturedSourceIp = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + null, + (ip, port) -> capturedSourceIp[0] = ip, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -482,24 +442,17 @@ void extractsAlbDataCorrectly() { @Test void extractsAlbMultiValueHeadersCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"httpMethod\": \"GET\",\n" - + " \"multiValueHeaders\": {\n" - + " \"accept\": [\"text/html\", \"application/json\"],\n" - + " \"x-custom\": [\"value1\", \"value2\"]\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:...\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/test\"," + + "\"httpMethod\": \"GET\"," + + "\"multiValueHeaders\": {\"accept\": [\"text/html\", \"application/json\"], \"x-custom\": [\"value1\", \"value2\"]}," + + "\"requestContext\": {\"elb\": {\"targetGroupArn\": \"arn:aws:...\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(null, capturedHeaders::put, null, null, null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -511,43 +464,34 @@ void extractsAlbMultiValueHeadersCorrectly() { @Test void handlesMultiValueHeadersWithEmptyList() { String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"httpMethod\": \"GET\",\n" - + " \"multiValueHeaders\": {\n" - + " \"accept\": [],\n" - + " \"x-custom\": [\"value1\"]\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:...\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/test\"," + + "\"httpMethod\": \"GET\"," + + "\"multiValueHeaders\": {\"accept\": [], \"x-custom\": [\"value1\"]}," + + "\"requestContext\": {\"elb\": {\"targetGroupArn\": \"arn:aws:...\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(null, capturedHeaders::put, null, null, null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedHeaders.get("accept")); // Empty list should result in empty string + assertEquals("", capturedHeaders.get("accept")); assertEquals("value1", capturedHeaders.get("x-custom")); } @Test void extractsWebSocketDataCorrectly() { String eventJson = - "{\n" - + " \"requestContext\": {\n" - + " \"routeKey\": \"$connect\",\n" - + " \"connectionId\": \"conn-abc123\",\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"192.168.0.100\"\n" - + " }\n" - + " }\n" + "{" + + "\"requestContext\": {" + + " \"routeKey\": \"$connect\"," + + " \"connectionId\": \"conn-abc123\"," + + " \"identity\": {\"sourceIp\": \"192.168.0.100\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -556,13 +500,14 @@ void extractsWebSocketDataCorrectly() { String[] capturedSourceIp = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + null, + (ip, port) -> capturedSourceIp[0] = ip, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -577,20 +522,18 @@ void handlesBase64EncodedBodyCorrectly() { String originalBody = "This is test data"; String base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()); String eventJson = - "{\n" - + " \"body\": \"" + "{" + + "\"body\": \"" + base64Body - + "\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + + "\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Object[] capturedBody = {null}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -605,12 +548,12 @@ void handlesNullBodyCorrectly() { String[] capturedBody = {"NOT_CALLED"}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = String.valueOf(body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("NOT_CALLED", capturedBody[0]); // Callback should not be invoked for null body + assertEquals("NOT_CALLED", capturedBody[0]); } @Test @@ -620,22 +563,20 @@ void handlesEmptyBodyCorrectly() { Object[] capturedBody = {null}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedBody[0]); // Empty body is passed as empty string to WAF + assertEquals("", capturedBody[0]); } @Test void handlesPathWithQueryStringCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/api/users?id=123&filter=active\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\"\n" - + " }\n" + "{" + + "\"path\": \"/api/users?id=123&filter=active\"," + + "\"requestContext\": {\"httpMethod\": \"GET\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -643,12 +584,14 @@ void handlesPathWithQueryStringCorrectly() { String[] capturedQuery = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedPath[0] = uri.path(); - capturedQuery[0] = uri.query(); - })); + (method, uri) -> { + capturedPath[0] = uri.path(); + capturedQuery[0] = uri.query(); + }, + null, + null, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -660,163 +603,149 @@ void handlesPathWithQueryStringCorrectly() { @Test void extractsSchemeAndPortFromXForwardedHeaders() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-proto\": \"http\",\n" - + " \"x-forwarded-port\": \"8080\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"x-forwarded-proto\": \"http\", \"x-forwarded-port\": \"8080\"}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - })); + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + }, + null, + null, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); assertEquals("http", capturedScheme[0]); - assertEquals(Integer.valueOf(8080), capturedPort[0]); + assertEquals(8080, capturedPort[0]); } @Test void fallsBackToHttps443WhenXForwardedHeadersAreAbsent() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {},\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - })); + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + }, + null, + null, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); assertEquals("https", capturedScheme[0]); - assertEquals(Integer.valueOf(443), capturedPort[0]); + assertEquals(443, capturedPort[0]); } @Test void handlesInvalidXForwardedPortGracefully() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-proto\": \"https\",\n" - + " \"x-forwarded-port\": \"not-a-number\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"x-forwarded-proto\": \"https\", \"x-forwarded-port\": \"not-a-number\"}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - })); + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + }, + null, + null, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); assertEquals("https", capturedScheme[0]); - assertEquals(Integer.valueOf(443), capturedPort[0]); + assertEquals(443, capturedPort[0]); } @Test void handlesInvalidBase64BodyGracefully() { String eventJson = - "{\n" - + " \"body\": \"not-valid-base64\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + "{" + + "\"body\": \"not-valid-base64\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedBody = {"NOT_CALLED"}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = String.valueOf(body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("NOT_CALLED", capturedBody[0]); // Should not call body callback when decode fails + assertEquals("NOT_CALLED", capturedBody[0]); } @Test void handlesBase64DecodedEmptyStringBody() { String base64Empty = Base64.getEncoder().encodeToString("".getBytes()); String eventJson = - "{\n" - + " \"body\": \"" + "{" + + "\"body\": \"" + base64Empty - + "\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + + "\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); - Object[] capturedBody = {"NOT_CALLED"}; + Object[] capturedBody = {null}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedBody[0]); // Should pass empty string after decoding + assertEquals("", capturedBody[0]); } @Test + @SuppressWarnings("unchecked") void handlesBodyWithSpecialCharacters() { String eventJson = - "{\n" - + " \"body\": \"{\\\"text\\\": \\\"Hello \\u4e16\\u754c \\uD83C\\uDF0D\\\"}\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + "{" + + "\"body\": \"{\\\"text\\\": \\\"Hello \\u4e16\\u754c \\uD83C\\uDF0D\\\"}\"," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Object[] capturedBody = {null}; - setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -832,18 +761,12 @@ void handlesBodyWithSpecialCharacters() { @Test void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { String eventJson = - "{\n" - + " \"path\": \"/generic/path\",\n" - + " \"httpMethod\": \"PATCH\",\n" - + " \"headers\": {\n" - + " \"x-custom-header\": \"generic-value\"\n" - + " },\n" - + " \"unknownField\": \"should be ignored\",\n" - + " \"requestContext\": {\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"203.0.113.1\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/generic/path\"," + + "\"httpMethod\": \"PATCH\"," + + "\"headers\": {\"x-custom-header\": \"generic-value\"}," + + "\"unknownField\": \"should be ignored\"," + + "\"requestContext\": {\"identity\": {\"sourceIp\": \"203.0.113.1\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -853,14 +776,14 @@ void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { String[] capturedSourceIp = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) - .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + capturedHeaders::put, + (ip, port) -> capturedSourceIp[0] = ip, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -874,14 +797,10 @@ void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { @Test void extractsDataFromUnknownTriggerWithHttpInRequestContext() { String eventJson = - "{\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"OPTIONS\",\n" - + " \"path\": \"/options/path\",\n" - + " \"sourceIp\": \"198.51.100.50\"\n" - + " }\n" - + " }\n" + "{" + + "\"requestContext\": {" + + " \"http\": {\"method\": \"OPTIONS\", \"path\": \"/options/path\", \"sourceIp\": \"198.51.100.50\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -890,13 +809,14 @@ void extractsDataFromUnknownTriggerWithHttpInRequestContext() { String[] capturedSourceIp = {null}; setupMockCallbacks( - new Callbacks() - .onMethodUri( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }) - .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }, + null, + (ip, port) -> capturedSourceIp[0] = ip, + null, + null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -909,23 +829,16 @@ void extractsDataFromUnknownTriggerWithHttpInRequestContext() { @Test void handlesCookiesMergingWithExistingCookieHeader() { String eventJson = - "{\n" - + " \"headers\": {\n" - + " \"cookie\": \"existing=value\"\n" - + " },\n" - + " \"cookies\": [\"new=cookie1\", \"another=cookie2\"],\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/\"\n" - + " }\n" - + " }\n" + "{" + + "\"headers\": {\"cookie\": \"existing=value\"}," + + "\"cookies\": [\"new=cookie1\", \"another=cookie2\"]," + + "\"requestContext\": {\"http\": {\"method\": \"GET\", \"path\": \"/\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(null, capturedHeaders::put, null, null, null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -936,28 +849,21 @@ void handlesCookiesMergingWithExistingCookieHeader() { @Test void handlesEmptyCookiesArrayCorrectly() { String eventJson = - "{\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"cookies\": [],\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/\"\n" - + " }\n" - + " }\n" + "{" + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"cookies\": []," + + "\"requestContext\": {\"http\": {\"method\": \"GET\", \"path\": \"/\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(null, capturedHeaders::put, null, null, null); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertFalse(capturedHeaders.containsKey("cookie")); // Empty array should not add cookie header + assertTrue(!capturedHeaders.containsKey("cookie")); } // ============================================================================ @@ -966,30 +872,28 @@ void handlesEmptyCookiesArrayCorrectly() { @Test void processRequestEndDoesNothingWhenSpanIsNull() { - // No exception should be thrown LambdaAppSecHandler.processRequestEnd(null); + // no exception expected } @Test void processRequestEndDoesNothingWhenAppSecIsDisabled() { ActiveSubsystems.APPSEC_ACTIVE = false; AgentSpan span = mock(AgentSpan.class); - LambdaAppSecHandler.processRequestEnd(span); - - verifyNoInteractions(span); + verify(span, never()).getRequestContext(); } @Test void processRequestEndDoesNothingWhenSpanHasNoRequestContext() { AgentSpan span = mock(AgentSpan.class); when(span.getRequestContext()).thenReturn(null); - - // No exception should be thrown LambdaAppSecHandler.processRequestEnd(span); + // no exception expected } @Test + @SuppressWarnings("unchecked") void processRequestEndInvokesRequestEndedCallbackWithRequestContext() { Object mockAppSecContext = new Object(); RequestContext mockRequestContext = mock(RequestContext.class); @@ -1001,31 +905,29 @@ void processRequestEndInvokesRequestEndedCallbackWithRequestContext() { RequestContext[] capturedContext = {null}; AgentSpan[] capturedSpan = {null}; - BiFunction> mockRequestEndedCallback = + BiFunction> requestEndedCallback = mock(BiFunction.class); doAnswer( inv -> { + callbackInvoked[0] = true; capturedContext[0] = inv.getArgument(0); capturedSpan[0] = inv.getArgument(1); - callbackInvoked[0] = true; return new Flow.ResultFlow<>(null); }) - .when(mockRequestEndedCallback) - .apply(any(), any()); + .when(requestEndedCallback) + .apply(any(RequestContext.class), any()); CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); - when(mockCallbackProvider.getCallback(EVENTS.requestEnded())) - .thenReturn(mockRequestEndedCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestEnded())).thenReturn(requestEndedCallback); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); LambdaAppSecHandler.processRequestEnd(span); - assertEquals(true, callbackInvoked[0]); + assertTrue(callbackInvoked[0]); assertEquals(mockRequestContext, capturedContext[0]); assertEquals(span, capturedSpan[0]); } @@ -1042,11 +944,10 @@ void processRequestEndHandlesNullRequestEndedCallbackGracefully() { AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); - // No exception should be thrown - should log warning but not throw LambdaAppSecHandler.processRequestEnd(span); + // no exception expected } // ============================================================================ @@ -1055,37 +956,26 @@ void processRequestEndHandlesNullRequestEndedCallbackGracefully() { @Test void mergeContextsReturnsNullWhenBothContextsAreNull() { - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, null); - - assertNull(result); + assertNull(LambdaAppSecHandler.mergeContexts(null, null)); } @Test void mergeContextsReturnsExtensionContextWhenAppSecContextIsNull() { TagContext extensionContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, null); - - assertEquals(extensionContext, result); + assertEquals(extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, null)); } @Test void mergeContextsReturnsAppSecContextWhenExtensionContextIsNull() { TagContext appSecContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, appSecContext); - - assertEquals(appSecContext, result); + assertEquals(appSecContext, LambdaAppSecHandler.mergeContexts(null, appSecContext)); } @Test void mergeContextsMergesAppSecDataIntoTagContext() { Object appSecData = new Object(); - - // Create real TagContext instances since methods are final TagContext appSecContext = new TagContext(); appSecContext.withRequestContextDataAppSec(appSecData); - TagContext extensionContext = new TagContext(); AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); @@ -1098,20 +988,16 @@ void mergeContextsMergesAppSecDataIntoTagContext() { void mergeContextsReturnsExtensionContextWhenAppSecContextIsNotTagContext() { TagContext extensionContext = mock(TagContext.class); AgentSpanContext appSecContext = mock(AgentSpanContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); - - assertEquals(extensionContext, result); + assertEquals( + extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext)); } @Test void mergeContextsReturnsExtensionContextWhenItIsNotTagContext() { AgentSpanContext extensionContext = mock(AgentSpanContext.class); TagContext appSecContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); - - assertEquals(extensionContext, result); + assertEquals( + extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext)); } // ============================================================================ @@ -1119,6 +1005,7 @@ void mergeContextsReturnsExtensionContextWhenItIsNotTagContext() { // ============================================================================ @Test + @SuppressWarnings("unchecked") void processRequestStartHandlesNullRequestStartedCallbackGracefully() { String eventJson = "{\"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -1129,232 +1016,797 @@ void processRequestStartHandlesNullRequestStartedCallbackGracefully() { AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); // Should return null when requestStarted callback is missing + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test + @SuppressWarnings("unchecked") void processRequestStartHandlesNullMethodUriCallbackGracefully() { - String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\"\n" - + " }\n" - + "}"; + String eventJson = "{\"path\": \"/test\", \"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(eventJson); Object mockAppSecContext = new Object(); - - Supplier> mockRequestStartedCallback = mock(Supplier.class); - when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); - - Function> mockHeaderDoneCallback = mock(Function.class); - when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + Supplier> requestStartedCallback = mock(Supplier.class); + when(requestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) - .thenReturn(mockRequestStartedCallback); + .thenReturn(requestStartedCallback); when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())).thenReturn(null); + Function> headerDoneCallback = mock(Function.class); + when(headerDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) - .thenReturn(mockHeaderDoneCallback); + .thenReturn(headerDoneCallback); when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())).thenReturn(null); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - assertNotNull(result); // Should continue processing even if methodUri callback is null + assertNotNull(result); assertInstanceOf(TagContext.class, result); } @Test void processRequestStartHandlesExceptionDuringJsonParsing() { - String invalidJson = "{this is not valid JSON at all"; - ByteArrayInputStream event = createInputStream(invalidJson); + ByteArrayInputStream event = createInputStream("{this is not valid JSON at all"); + assertNull(LambdaAppSecHandler.processRequestStart(event)); + } - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + @Test + void processRequestStartHandlesExceptionDuringStreamReading() { + ByteArrayInputStream mockStream = + new ByteArrayInputStream("data".getBytes()) { + @Override + public synchronized int available() { + throw new RuntimeException("Stream error"); + } + }; + assertNull(LambdaAppSecHandler.processRequestStart(mockStream)); + } + + // ============================================================================ + // processResponseData Tests β€” guard conditions + // ============================================================================ + + @Test + void processResponseDataDoesNothingWhenAppSecIsDisabled() { + ActiveSubsystems.APPSEC_ACTIVE = false; + AgentSpan span = mock(AgentSpan.class); + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": \"ok\"}"); + LambdaAppSecHandler.processResponseData(span, result); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingForNullSpan() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + LambdaAppSecHandler.processResponseData(null, result); + // no exception expected + } + + @Test + void processResponseDataDoesNothingForNonByteArrayOutputStreamResult() { + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, "string result"); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingForNullResult() { + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, null); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingWhenSpanHasNoRequestContext() { + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(null); + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + setupMockResponseCallbacks(null, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected + } + + @Test + void processResponseDataDoesNothingForOversizedResponse() { + int maxSize = Config.get().getAppSecBodyParsingSizeLimit(); + char[] chars = new char[maxSize + 1]; + java.util.Arrays.fill(chars, 'x'); + ByteArrayOutputStream result = createOutputStream(new String(chars)); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataDoesNothingForEmptyByteArrayOutputStream() { + ByteArrayOutputStream result = new ByteArrayOutputStream(); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Trigger type gating and fallback --- + + @Test + void processResponseDataSkipsNonApiGwResponseWhenTriggerTypeIsUnknown() { + LambdaAppSecHandler.setCurrentTriggerType(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN); + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, null, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertFalse(headerDoneCalled[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataAppliesFallbackForHttpTriggerWithPlainJsonResponse() { + LambdaAppSecHandler.setCurrentTriggerType(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL); + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + capturedHeaders::put, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertTrue(headerDoneCalled[0]); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("hello", ((Map) capturedBody[0]).get("result")); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataKeepsParsedHeadersAndBodyWhenStatusCodeIsZero() { + // A response that has statusCode:0 with explicit headers/body should use the parsed data, + // not discard it in favour of the plain-response fallback. + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST); + ByteArrayOutputStream result = + createOutputStream( + "{\"statusCode\": 0, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"hello\"}"); + Integer[] capturedStatus = {null}; + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + capturedHeaders::put, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); // statusCode 0 β€” responseStarted not fired + assertEquals("text/plain", capturedHeaders.get("content-type")); // parsed header kept + assertTrue(headerDoneCalled[0]); + assertEquals("hello", capturedBody[0]); // parsed body kept, not the whole envelope + } + + @Test + void processResponseDataAppliesFallbackForHttpTriggerWithNonJsonStringResponse() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST); + // JSON-encoded string (as returned by a RequestHandler) + ByteArrayOutputStream result = createOutputStream("\"Hello World!\""); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + null, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertTrue(headerDoneCalled[0]); + assertEquals("Hello World!", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataWebSocketWithStatusCodeFiresResponseStarted() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET); + // $connect handler returning a proper statusCode β€” should be treated like any API-GW response + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, null, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(200, capturedStatus[0]); + assertTrue(headerDoneCalled[0]); + } + + @Test + void processResponseDataWebSocketWithoutStatusCodeUsesFallbackWithNoStatus() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET); + // Message-route handler returning arbitrary data β€” no statusCode, fallback path + ByteArrayOutputStream result = createOutputStream("{\"message\": \"hello\"}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + null, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); // no responseStarted for status-less WebSocket messages + assertTrue(headerDoneCalled[0]); + assertInstanceOf(Map.class, capturedBody[0]); + } + + @Test + void processResponseDataSkipsNonApiGwResponseWhenTriggerTypeIsNull() { + // No processRequestStart called β€” thread-local is null β€” behaves like unknown + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Status code extraction --- + + @Test + void processResponseDataExtractsStatusCodeCorrectly() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(200, capturedStatus[0]); + } + + @Test + void processResponseDataExtractsStatusCodeAsIntegerFromDouble() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 404.0, \"body\": \"not found\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(404, capturedStatus[0]); + } + + @Test + void processResponseDataHandlesMissingStatusCode() { + ByteArrayOutputStream result = createOutputStream("{\"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataHandlesNonNumericStatusCode() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": \"bad\", \"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Header extraction --- + + @Test + void processResponseDataForwardsAllResponseHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\", \"x-custom\": \"val\", \"content-length\": \"42\", \"set-cookie\": \"a=1\"}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(4, capturedHeaders.size()); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertEquals("val", capturedHeaders.get("x-custom")); + assertEquals("42", capturedHeaders.get("content-length")); + assertEquals("a=1", capturedHeaders.get("set-cookie")); + } + + @Test + void processResponseDataLowercasesHeaderKeys() { + String json = + "{\"statusCode\": 200, \"headers\": {\"Content-Type\": \"text/html\", \"CONTENT-LENGTH\": \"10\"}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("10", capturedHeaders.get("content-length")); + } + + @Test + void processResponseDataMergesMultiValueHeadersWithSingleValueHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"content-encoding\": [\"gzip\", \"br\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("gzip, br", capturedHeaders.get("content-encoding")); + } + + @Test + void processResponseDataHandlesEmptyHeaders() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + null, capturedHeaders::put, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertTrue(capturedHeaders.isEmpty()); + assertTrue(headerDoneCalled[0]); + } - assertNull(result); // Should return null on parse error + // --- Body extraction --- + + @Test + @SuppressWarnings("unchecked") + void processResponseDataParsesJsonBody() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\"}, \"body\": \"{\\\"key\\\": \\\"value\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("value", ((Map) capturedBody[0]).get("key")); } @Test - void processRequestStartHandlesExceptionDuringStreamReading() throws IOException { - ByteArrayInputStream mockStream = mock(ByteArrayInputStream.class); - when(mockStream.available()).thenThrow(new IOException("Stream error")); + void processResponseDataHandlesNonJsonBodyAsRawString() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"plain text\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("plain text", capturedBody[0]); + } - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(mockStream); + @Test + @SuppressWarnings("unchecked") + void processResponseDataHandlesBase64EncodedBody() { + String originalBody = "{\"decoded\": \"content\"}"; + String base64Body = + Base64.getEncoder().encodeToString(originalBody.getBytes(StandardCharsets.UTF_8)); + String json = + "{\"statusCode\": 200, \"body\": \"" + base64Body + "\", \"isBase64Encoded\": true}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("content", ((Map) capturedBody[0]).get("decoded")); + } - assertNull(result); // Should return null on IO error + @Test + void processResponseDataHandlesNullBody() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": null}"); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + void processResponseDataHandlesMissingBodyField() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataAttemptsJsonParseWhenNoContentType() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 200, \"body\": \"{\\\"a\\\": 1}\"}"); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals(1.0d, ((Map) capturedBody[0]).get("a")); + } + + @Test + void processResponseDataFallsBackToRawStringWhenJsonParseFails() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 200, \"body\": \"not json {\"}"); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("not json {", capturedBody[0]); + } + + // --- Event ordering --- + + @Test + void processResponseDataFiresEventsInCorrectOrder() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\"}, \"body\": \"{\\\"k\\\": \\\"v\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + java.util.List order = new java.util.ArrayList<>(); + + AgentSpan span = + setupMockResponseCallbacks( + status -> order.add("responseStarted"), + (name, value) -> order.add("responseHeader"), + () -> order.add("responseHeaderDone"), + body -> order.add("responseBody")); + + LambdaAppSecHandler.processResponseData(span, result); + + assertEquals("responseStarted", order.get(0)); + assertTrue(order.stream().filter("responseHeader"::equals).count() >= 1); + int headerDoneIdx = order.indexOf("responseHeaderDone"); + int lastHeaderIdx = order.lastIndexOf("responseHeader"); + assertTrue(headerDoneIdx > lastHeaderIdx); + assertEquals("responseBody", order.get(order.size() - 1)); + } + + @Test + void processResponseDataHandlesInvalidBase64ResponseBodyGracefully() { + String json = + "{\"statusCode\": 200, \"body\": \"not-valid-base64!!!\", \"isBase64Encoded\": true}"; + ByteArrayOutputStream result = createOutputStream(json); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataParsesBodyAsJsonForJavascriptContentType() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/javascript\"}, \"body\": \"{\\\"key\\\": \\\"val\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("val", ((Map) capturedBody[0]).get("key")); + } + + @Test + void processResponseDataSkipsMultiValueHeadersEntryWithNonListValue() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"x-scalar\": \"not-a-list\", \"x-valid\": [\"v1\", \"v2\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("v1, v2", capturedHeaders.get("x-valid")); + assertTrue(!capturedHeaders.containsKey("x-scalar")); + } + + @Test + void processResponseDataMultiValueHeadersOverrideSingleValueHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"content-type\": [\"application/json\", \"charset=utf-8\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("application/json, charset=utf-8", capturedHeaders.get("content-type")); + } + + // --- Error handling --- + + @Test + void processResponseDataHandlesMalformedJsonResponse() { + ByteArrayOutputStream result = createOutputStream("{not valid json"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataHandlesEmptyStringResponse() { + ByteArrayOutputStream result = createOutputStream(""); + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected } // ============================================================================ - // Helper classes and methods + // processResponseData β€” null individual callback handling // ============================================================================ - private static class Callbacks { - BiConsumer onMethodUri; - BiConsumer onHeader; - BiConsumer onSocketAddress; - Consumer> onPathParams; - Consumer onBody; + @Test + void processResponseDataHandlesNullResponseHeaderDoneCallbackGracefully() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"ok\"}"; + ByteArrayOutputStream result = createOutputStream(json); - Callbacks onMethodUri(BiConsumer cb) { - this.onMethodUri = cb; - return this; - } + RequestContext mockRequestContext = mock(RequestContext.class); + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(mockRequestContext); - Callbacks onHeader(BiConsumer cb) { - this.onHeader = cb; - return this; - } + CallbackProvider cbp = mock(CallbackProvider.class); + when(cbp.getCallback(EVENTS.responseStarted())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseHeader())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseHeaderDone())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseBody())).thenReturn(null); - Callbacks onSocketAddress(BiConsumer cb) { - this.onSocketAddress = cb; - return this; - } + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)).thenReturn(cbp); + AgentTracer.forceRegister(mockTracer); - Callbacks onPathParams(Consumer> cb) { - this.onPathParams = cb; - return this; - } + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected β€” all null callbacks must be silently skipped + } - Callbacks onBody(Consumer cb) { - this.onBody = cb; - return this; - } + // ============================================================================ + // extractResponseData Unit Tests + // ============================================================================ + + @Test + void extractResponseDataReturnsNullForMalformedJson() { + assertNull(LambdaAppSecHandler.extractResponseData("{bad json")); } - private static Map mapOf(Object... keysAndValues) { - Map map = new LinkedHashMap<>(); - for (int i = 0; i < keysAndValues.length; i += 2) { - map.put((String) keysAndValues[i], keysAndValues[i + 1]); - } - return map; + @Test + void extractResponseDataReturnsNullForNullJsonParseResult() { + assertNull(LambdaAppSecHandler.extractResponseData("null")); } - private ByteArrayInputStream createInputStream(String json) { + @Test + void extractResponseDataReturnsNullForEmptyString() { + assertNull(LambdaAppSecHandler.extractResponseData("")); + } + + // ============================================================================ + // Helper Methods + // ============================================================================ + + private static ByteArrayInputStream createInputStream(String json) { return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); } - private void setupMockCallbacks(Callbacks callbacks) { - Object mockAppSecContext = new Object(); + private static ByteArrayOutputStream createOutputStream(String json) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + baos.write(json.getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + throw new RuntimeException(e); + } + return baos; + } - Supplier> mockRequestStartedCallback = mock(Supplier.class); - when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); + @FunctionalInterface + interface MethodUriCapture { + void accept(String method, URIDataAdapter uri); + } + + @FunctionalInterface + interface IpPortCapture { + void accept(String ip, int port); + } - TriFunction> mockMethodUriCallback = null; - if (callbacks.onMethodUri != null) { - mockMethodUriCallback = mock(TriFunction.class); - BiConsumer methodUriCb = callbacks.onMethodUri; + @SuppressWarnings("unchecked") + private void setupMockCallbacks( + MethodUriCapture onMethodUri, + BiConsumer onHeader, + IpPortCapture onSocketAddress, + Consumer> onPathParams, + Consumer onBody) { + + Object mockAppSecContext = new Object(); + Supplier> requestStartedCallback = mock(Supplier.class); + when(requestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); + + TriFunction> methodUriCallback = null; + if (onMethodUri != null) { + methodUriCallback = mock(TriFunction.class); + MethodUriCapture capture = onMethodUri; doAnswer( inv -> { - String method = inv.getArgument(1); - URIDataAdapter uri = inv.getArgument(2); - methodUriCb.accept(method, uri); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1), inv.getArgument(2)); + return Flow.ResultFlow.empty(); }) - .when(mockMethodUriCallback) - .apply(any(), any(), any()); + .when(methodUriCallback) + .apply(any(), anyString(), any(URIDataAdapter.class)); } - TriConsumer mockHeaderCallback = null; - if (callbacks.onHeader != null) { - mockHeaderCallback = mock(TriConsumer.class); - BiConsumer headerCb = callbacks.onHeader; + TriConsumer headerCallback = null; + if (onHeader != null) { + headerCallback = mock(TriConsumer.class); + BiConsumer capture = onHeader; doAnswer( inv -> { - String name = inv.getArgument(1); - String value = inv.getArgument(2); - headerCb.accept(name, value); + capture.accept(inv.getArgument(1), inv.getArgument(2)); return null; }) - .when(mockHeaderCallback) - .accept(any(), any(), any()); + .when(headerCallback) + .accept(any(), anyString(), anyString()); } - TriFunction> mockSocketAddressCallback = null; - if (callbacks.onSocketAddress != null) { - mockSocketAddressCallback = mock(TriFunction.class); - BiConsumer socketCb = callbacks.onSocketAddress; + TriFunction> socketAddressCallback = null; + if (onSocketAddress != null) { + socketAddressCallback = mock(TriFunction.class); + IpPortCapture capture = onSocketAddress; doAnswer( inv -> { - String ip = inv.getArgument(1); - Integer port = inv.getArgument(2); - socketCb.accept(ip, port); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1), (Integer) inv.getArgument(2)); + return Flow.ResultFlow.empty(); }) - .when(mockSocketAddressCallback) - .apply(any(), any(), any()); + .when(socketAddressCallback) + .apply(any(), anyString(), anyInt()); } - Function> mockHeaderDoneCallback = mock(Function.class); - when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + Function> headerDoneCallback = mock(Function.class); + when(headerDoneCallback.apply(any())).thenReturn(Flow.ResultFlow.empty()); - BiFunction, Flow> mockPathParamsCallback = null; - if (callbacks.onPathParams != null) { - mockPathParamsCallback = mock(BiFunction.class); - Consumer> pathParamsCb = callbacks.onPathParams; + BiFunction, Flow> pathParamsCallback = null; + if (onPathParams != null) { + pathParamsCallback = mock(BiFunction.class); + Consumer> capture = onPathParams; doAnswer( inv -> { - Map params = inv.getArgument(1); - pathParamsCb.accept(params); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1)); + return Flow.ResultFlow.empty(); }) - .when(mockPathParamsCallback) - .apply(any(), any()); + .when(pathParamsCallback) + .apply(any(), any(Map.class)); } - BiFunction> mockBodyCallback = null; - if (callbacks.onBody != null) { - mockBodyCallback = mock(BiFunction.class); - Consumer bodyCb = callbacks.onBody; + BiFunction> bodyCallback = null; + if (onBody != null) { + bodyCallback = mock(BiFunction.class); + Consumer capture = onBody; doAnswer( inv -> { - Object body = inv.getArgument(1); - bodyCb.accept(body); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1)); + return Flow.ResultFlow.empty(); }) - .when(mockBodyCallback) + .when(bodyCallback) .apply(any(), any()); } CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) - .thenReturn(mockRequestStartedCallback); + .thenReturn(requestStartedCallback); when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())) - .thenReturn(mockMethodUriCallback); - when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(mockHeaderCallback); + .thenReturn(methodUriCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(headerCallback); when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())) - .thenReturn(mockSocketAddressCallback); + .thenReturn(socketAddressCallback); when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) - .thenReturn(mockHeaderDoneCallback); + .thenReturn(headerDoneCallback); when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())) - .thenReturn(mockPathParamsCallback); - when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())) - .thenReturn(mockBodyCallback); + .thenReturn(pathParamsCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())).thenReturn(bodyCallback); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); } - private static String repeatChar(char ch, int count) { - char[] chars = new char[count]; - Arrays.fill(chars, ch); - return new String(chars); + @SuppressWarnings("unchecked") + private AgentSpan setupMockResponseCallbacks( + Consumer onResponseStarted, + BiConsumer onResponseHeader, + Runnable onResponseHeaderDone, + Consumer onResponseBody) { + + RequestContext mockRequestContext = mock(RequestContext.class); + AgentSpan mockSpan = mock(AgentSpan.class); + when(mockSpan.getRequestContext()).thenReturn(mockRequestContext); + + BiFunction> responseStartedCb = null; + if (onResponseStarted != null) { + responseStartedCb = mock(BiFunction.class); + Consumer capture = onResponseStarted; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1)); + return new Flow.ResultFlow<>(null); + }) + .when(responseStartedCb) + .apply(any(RequestContext.class), anyInt()); + } + + TriConsumer responseHeaderCb = null; + if (onResponseHeader != null) { + responseHeaderCb = mock(TriConsumer.class); + BiConsumer capture = onResponseHeader; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1), inv.getArgument(2)); + return null; + }) + .when(responseHeaderCb) + .accept(any(), anyString(), anyString()); + } + + Function> responseHeaderDoneCb = mock(Function.class); + if (onResponseHeaderDone != null) { + Runnable capture = onResponseHeaderDone; + doAnswer( + inv -> { + capture.run(); + return new Flow.ResultFlow<>(null); + }) + .when(responseHeaderDoneCb) + .apply(any(RequestContext.class)); + } else { + when(responseHeaderDoneCb.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + } + + BiFunction> responseBodyCb = null; + if (onResponseBody != null) { + responseBodyCb = mock(BiFunction.class); + Consumer capture = onResponseBody; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1)); + return new Flow.ResultFlow<>(null); + }) + .when(responseBodyCb) + .apply(any(RequestContext.class), any()); + } + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.responseStarted())).thenReturn(responseStartedCb); + when(mockCallbackProvider.getCallback(EVENTS.responseHeader())).thenReturn(responseHeaderCb); + when(mockCallbackProvider.getCallback(EVENTS.responseHeaderDone())) + .thenReturn(responseHeaderDoneCb); + when(mockCallbackProvider.getCallback(EVENTS.responseBody())).thenReturn(responseBodyCb); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + AgentTracer.forceRegister(mockTracer); + + return mockSpan; } } From 6314204d867b5e7647d45d3bb00766dd3813c622 Mon Sep 17 00:00:00 2001 From: Clara Poncet Date: Fri, 26 Jun 2026 17:51:17 +0200 Subject: [PATCH 12/15] PR review fixes: revert broken startSpan calls + test quality cleanup --- .../lambda/LambdaHandlerInstrumentation.java | 13 +- .../trace/lambda/LambdaAppSecHandlerTest.java | 463 +++++++++--------- 2 files changed, 255 insertions(+), 221 deletions(-) diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index 2c1370a3a80..6f391c63e80 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -94,9 +94,9 @@ static AgentScope enter( AgentSpanContext lambdaContext = AgentTracer.get().notifyLambdaStart(in, lambdaRequestId); final AgentSpan span; if (null == lambdaContext) { - span = startSpan(INVOCATION_SPAN_NAME); + span = startSpan("java-aws-sdk", INVOCATION_SPAN_NAME); } else { - span = startSpan(INVOCATION_SPAN_NAME, lambdaContext); + span = startSpan("java-aws-sdk", INVOCATION_SPAN_NAME, lambdaContext); } span.setSpanType(InternalSpanTypes.SERVERLESS); span.setTag("request_id", lambdaRequestId); @@ -132,6 +132,15 @@ static void exit( // (filter_span_from_lambda_library_or_runtime in // bottlecap/src/traces/trace_processor.rs, which compares // span.resource == "dd-tracer-serverless-span") drops the placeholder. + // Other instrumentation (HTTP/JAX-RS) may have overwritten it with the + // route ("POST /") during the invocation, in which case the extension + // would fail to dedup, leading to the placeholder leaking to the backend + // with parent_id=0 and detaching the inferred apigateway root from the + // rest of the trace. + // Use TAG_INTERCEPTOR priority because DDSpanContext.setResourceName + // ignores writes whose priority is below the current resource priority, + // and the HTTP/JAX-RS instrumentation will already have written + // HTTP_FRAMEWORK_ROUTE (3) by this point. span.setResourceName(INVOCATION_SPAN_NAME, ResourceNamePriorities.TAG_INTERCEPTOR); span.finish(); AgentTracer.get().notifyExtensionEnd(span, result, null != throwable, lambdaRequestId); diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java index 4b39677f046..0e6526aaf2b 100644 --- a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import datadog.trace.api.Config; @@ -35,8 +36,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Base64; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; @@ -136,130 +139,133 @@ void streamCanBeReadMultipleTimesAfterProcessing() throws IOException { @Test void detectsApiGatewayV1RestTriggerType() { - Map event = new HashMap<>(); - Map requestContext = new HashMap<>(); - requestContext.put("httpMethod", "GET"); - requestContext.put("requestId", "abc123"); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf("requestContext", mapOf("httpMethod", "GET", "requestId", "abc123")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST, triggerType); } @Test void detectsApiGatewayV2HttpTriggerType() { - Map http = new HashMap<>(); - http.put("method", "POST"); - http.put("path", "/api"); - Map requestContext = new HashMap<>(); - requestContext.put("http", http); - requestContext.put("domainName", "api.example.com"); - Map event = new HashMap<>(); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf( + "requestContext", + mapOf( + "http", mapOf("method", "POST", "path", "/api"), "domainName", "api.example.com")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP, triggerType); } @Test void detectsLambdaFunctionUrlTriggerType() { - Map http = new HashMap<>(); - http.put("method", "GET"); - http.put("path", "/"); - Map requestContext = new HashMap<>(); - requestContext.put("http", http); - requestContext.put("domainName", "xyz123.lambda-url.us-east-1.on.aws"); - Map event = new HashMap<>(); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf( + "requestContext", + mapOf( + "http", + mapOf("method", "GET", "path", "/"), + "domainName", + "xyz123.lambda-url.us-east-1.on.aws")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); } @Test void detectsAlbTriggerTypeWithoutMultiValueHeaders() { - Map elb = new HashMap<>(); - elb.put("targetGroupArn", "arn:aws:..."); - Map requestContext = new HashMap<>(); - requestContext.put("elb", elb); - Map event = new HashMap<>(); - event.put("httpMethod", "GET"); - event.put("path", "/"); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.ALB, LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf( + "httpMethod", + "GET", + "path", + "/", + "requestContext", + mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB, triggerType); } @Test void detectsAlbTriggerTypeWithMultiValueHeaders() { - Map elb = new HashMap<>(); - elb.put("targetGroupArn", "arn:aws:..."); - Map requestContext = new HashMap<>(); - requestContext.put("elb", elb); - Map event = new HashMap<>(); - event.put("httpMethod", "GET"); - event.put("path", "/"); - event.put("multiValueHeaders", new HashMap<>()); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf( + "httpMethod", + "GET", + "path", + "/", + "multiValueHeaders", + mapOf("accept", Arrays.asList("text/html", "application/json")), + "requestContext", + mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE, triggerType); } @Test void detectsWebSocketTriggerTypeWithRouteKey() { - Map requestContext = new HashMap<>(); - requestContext.put("connectionId", "conn-123"); - requestContext.put("routeKey", "$connect"); - Map event = new HashMap<>(); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf("requestContext", mapOf("connectionId", "conn-123", "routeKey", "$connect")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); } @Test void detectsWebSocketTriggerTypeWithEventType() { - Map requestContext = new HashMap<>(); - requestContext.put("connectionId", "conn-456"); - requestContext.put("eventType", "CONNECT"); - Map event = new HashMap<>(); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf("requestContext", mapOf("connectionId", "conn-456", "eventType", "CONNECT")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); } @Test void detectsUnknownTriggerTypeForUnrecognizedEvents() { - Map event = new HashMap<>(); - event.put("someUnknownField", "value"); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = mapOf("someUnknownField", "value"); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); } @Test void detectsUnknownTriggerTypeForEmptyRequestContext() { - Map event = new HashMap<>(); - event.put("requestContext", new HashMap<>()); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = mapOf("requestContext", mapOf()); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); } @Test void detectsLambdaUrlWhenHttpPresentButNoDomainName() { - Map http = new HashMap<>(); - http.put("method", "GET"); - http.put("path", "/ambiguous"); - Map requestContext = new HashMap<>(); - requestContext.put("http", http); - Map event = new HashMap<>(); - event.put("requestContext", requestContext); - assertEquals( - LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, - LambdaAppSecHandler.detectTriggerType(event)); + Map event = + mapOf("requestContext", mapOf("http", mapOf("method", "GET", "path", "/ambiguous"))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); } // ============================================================================ @@ -293,17 +299,20 @@ void extractsApiGatewayV1RestDataCorrectly() { Object[] capturedBody = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - capturedHeaders::put, - (ip, port) -> { - capturedSourceIp[0] = ip; - capturedSourcePort[0] = port; - }, - params -> capturedPathParams[0] = params, - body -> capturedBody[0] = body); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader(capturedHeaders::put) + .onSocketAddress( + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }) + .onPathParams(params -> capturedPathParams[0] = params) + .onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -345,17 +354,19 @@ void extractsApiGatewayV2HttpDataCorrectly() { Map[] capturedPathParams = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - capturedHeaders::put, - (ip, port) -> { - capturedSourceIp[0] = ip; - capturedSourcePort[0] = port; - }, - params -> capturedPathParams[0] = params, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader(capturedHeaders::put) + .onSocketAddress( + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }) + .onPathParams(params -> capturedPathParams[0] = params)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -388,14 +399,12 @@ void extractsLambdaFunctionUrlDataCorrectly() { String[] capturedPath = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - null, - null, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + })); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -422,14 +431,13 @@ void extractsAlbDataCorrectly() { String[] capturedSourceIp = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - null, - (ip, port) -> capturedSourceIp[0] = ip, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -452,7 +460,7 @@ void extractsAlbMultiValueHeadersCorrectly() { Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(null, capturedHeaders::put, null, null, null); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -474,7 +482,7 @@ void handlesMultiValueHeadersWithEmptyList() { Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(null, capturedHeaders::put, null, null, null); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -500,14 +508,13 @@ void extractsWebSocketDataCorrectly() { String[] capturedSourceIp = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - null, - (ip, port) -> capturedSourceIp[0] = ip, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -533,7 +540,7 @@ void handlesBase64EncodedBodyCorrectly() { Object[] capturedBody = {null}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -548,7 +555,7 @@ void handlesNullBodyCorrectly() { String[] capturedBody = {"NOT_CALLED"}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = String.valueOf(body)); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -563,7 +570,7 @@ void handlesEmptyBodyCorrectly() { Object[] capturedBody = {null}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -584,14 +591,12 @@ void handlesPathWithQueryStringCorrectly() { String[] capturedQuery = {null}; setupMockCallbacks( - (method, uri) -> { - capturedPath[0] = uri.path(); - capturedQuery[0] = uri.query(); - }, - null, - null, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedPath[0] = uri.path(); + capturedQuery[0] = uri.query(); + })); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -614,14 +619,12 @@ void extractsSchemeAndPortFromXForwardedHeaders() { int[] capturedPort = {-1}; setupMockCallbacks( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - }, - null, - null, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -644,14 +647,12 @@ void fallsBackToHttps443WhenXForwardedHeadersAreAbsent() { int[] capturedPort = {-1}; setupMockCallbacks( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - }, - null, - null, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -674,14 +675,12 @@ void handlesInvalidXForwardedPortGracefully() { int[] capturedPort = {-1}; setupMockCallbacks( - (method, uri) -> { - capturedScheme[0] = uri.scheme(); - capturedPort[0] = uri.port(); - }, - null, - null, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -702,7 +701,7 @@ void handlesInvalidBase64BodyGracefully() { String[] capturedBody = {"NOT_CALLED"}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = String.valueOf(body)); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -725,7 +724,7 @@ void handlesBase64DecodedEmptyStringBody() { Object[] capturedBody = {null}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -745,7 +744,7 @@ void handlesBodyWithSpecialCharacters() { Object[] capturedBody = {null}; - setupMockCallbacks(null, null, null, null, body -> capturedBody[0] = body); + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -776,14 +775,14 @@ void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { String[] capturedSourceIp = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - capturedHeaders::put, - (ip, port) -> capturedSourceIp[0] = ip, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader(capturedHeaders::put) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -809,14 +808,13 @@ void extractsDataFromUnknownTriggerWithHttpInRequestContext() { String[] capturedSourceIp = {null}; setupMockCallbacks( - (method, uri) -> { - capturedMethod[0] = method; - capturedPath[0] = uri.path(); - }, - null, - (ip, port) -> capturedSourceIp[0] = ip, - null, - null); + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -838,7 +836,7 @@ void handlesCookiesMergingWithExistingCookieHeader() { Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(null, capturedHeaders::put, null, null, null); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -858,12 +856,12 @@ void handlesEmptyCookiesArrayCorrectly() { Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(null, capturedHeaders::put, null, null, null); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertTrue(!capturedHeaders.containsKey("cookie")); + assertFalse(capturedHeaders.containsKey("cookie")); } // ============================================================================ @@ -880,8 +878,10 @@ void processRequestEndDoesNothingWhenSpanIsNull() { void processRequestEndDoesNothingWhenAppSecIsDisabled() { ActiveSubsystems.APPSEC_ACTIVE = false; AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processRequestEnd(span); - verify(span, never()).getRequestContext(); + + verifyNoInteractions(span); } @Test @@ -1522,7 +1522,7 @@ void processResponseDataSkipsMultiValueHeadersEntryWithNonListValue() { LambdaAppSecHandler.processResponseData(span, result); assertEquals("text/html", capturedHeaders.get("content-type")); assertEquals("v1, v2", capturedHeaders.get("x-valid")); - assertTrue(!capturedHeaders.containsKey("x-scalar")); + assertFalse(capturedHeaders.containsKey("x-scalar")); } @Test @@ -1621,32 +1621,49 @@ private static ByteArrayOutputStream createOutputStream(String json) { return baos; } - @FunctionalInterface - interface MethodUriCapture { - void accept(String method, URIDataAdapter uri); - } + private static class Callbacks { + BiConsumer onMethodUri; + BiConsumer onHeader; + BiConsumer onSocketAddress; + Consumer> onPathParams; + Consumer onBody; + + Callbacks onMethodUri(BiConsumer cb) { + this.onMethodUri = cb; + return this; + } - @FunctionalInterface - interface IpPortCapture { - void accept(String ip, int port); + Callbacks onHeader(BiConsumer cb) { + this.onHeader = cb; + return this; + } + + Callbacks onSocketAddress(BiConsumer cb) { + this.onSocketAddress = cb; + return this; + } + + Callbacks onPathParams(Consumer> cb) { + this.onPathParams = cb; + return this; + } + + Callbacks onBody(Consumer cb) { + this.onBody = cb; + return this; + } } @SuppressWarnings("unchecked") - private void setupMockCallbacks( - MethodUriCapture onMethodUri, - BiConsumer onHeader, - IpPortCapture onSocketAddress, - Consumer> onPathParams, - Consumer onBody) { - + private void setupMockCallbacks(Callbacks callbacks) { Object mockAppSecContext = new Object(); Supplier> requestStartedCallback = mock(Supplier.class); when(requestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); TriFunction> methodUriCallback = null; - if (onMethodUri != null) { + if (callbacks.onMethodUri != null) { methodUriCallback = mock(TriFunction.class); - MethodUriCapture capture = onMethodUri; + BiConsumer capture = callbacks.onMethodUri; doAnswer( inv -> { capture.accept(inv.getArgument(1), inv.getArgument(2)); @@ -1657,9 +1674,9 @@ private void setupMockCallbacks( } TriConsumer headerCallback = null; - if (onHeader != null) { + if (callbacks.onHeader != null) { headerCallback = mock(TriConsumer.class); - BiConsumer capture = onHeader; + BiConsumer capture = callbacks.onHeader; doAnswer( inv -> { capture.accept(inv.getArgument(1), inv.getArgument(2)); @@ -1670,9 +1687,9 @@ private void setupMockCallbacks( } TriFunction> socketAddressCallback = null; - if (onSocketAddress != null) { + if (callbacks.onSocketAddress != null) { socketAddressCallback = mock(TriFunction.class); - IpPortCapture capture = onSocketAddress; + BiConsumer capture = callbacks.onSocketAddress; doAnswer( inv -> { capture.accept(inv.getArgument(1), (Integer) inv.getArgument(2)); @@ -1686,9 +1703,9 @@ private void setupMockCallbacks( when(headerDoneCallback.apply(any())).thenReturn(Flow.ResultFlow.empty()); BiFunction, Flow> pathParamsCallback = null; - if (onPathParams != null) { + if (callbacks.onPathParams != null) { pathParamsCallback = mock(BiFunction.class); - Consumer> capture = onPathParams; + Consumer> capture = callbacks.onPathParams; doAnswer( inv -> { capture.accept(inv.getArgument(1)); @@ -1699,9 +1716,9 @@ private void setupMockCallbacks( } BiFunction> bodyCallback = null; - if (onBody != null) { + if (callbacks.onBody != null) { bodyCallback = mock(BiFunction.class); - Consumer capture = onBody; + Consumer capture = callbacks.onBody; doAnswer( inv -> { capture.accept(inv.getArgument(1)); @@ -1731,6 +1748,14 @@ private void setupMockCallbacks( AgentTracer.forceRegister(mockTracer); } + private static Map mapOf(Object... keysAndValues) { + Map map = new LinkedHashMap<>(); + for (int i = 0; i < keysAndValues.length; i += 2) { + map.put((String) keysAndValues[i], keysAndValues[i + 1]); + } + return map; + } + @SuppressWarnings("unchecked") private AgentSpan setupMockResponseCallbacks( Consumer onResponseStarted, From 3d2a12a4fbd5f7a1815663808cb5e7061acf7def Mon Sep 17 00:00:00 2001 From: Clara Poncet Date: Fri, 26 Jun 2026 18:00:44 +0200 Subject: [PATCH 13/15] spotlessApply --- .../LambdaHandlerInstrumentationTest.java | 30 +++++++------------ .../trace/lambda/LambdaAppSecHandler.java | 3 +- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java index 9ab0606f2d2..9c55c8d5047 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java @@ -165,8 +165,7 @@ void testLambdaStreamingHandler() throws IOException { ByteArrayOutputStream output = new ByteArrayOutputStream(); new HandlerStreaming().handleRequest(input, output, newContext()); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -229,8 +228,7 @@ void appSecCallbacksAreInvokedForApiGatewayV1Event() throws IOException { assertEquals("application/json", capturedHeaders.get("content-type")); assertTrue(capturedBody instanceof Map); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -264,8 +262,7 @@ void appSecCallbacksAreInvokedForApiGatewayV2HttpEvent() throws IOException { assertEquals("session=abc123", capturedHeaders.get("cookie")); assertTrue(capturedBody instanceof Map); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -286,8 +283,7 @@ void appSecCallbacksAreNotInvokedWhenAppSecIsDisabled() throws IOException { assertNull(capturedMethod); assertFalse(appSecEnded); assertNull(capturedResponseStatus); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -314,8 +310,7 @@ void responseCallbacksAreInvokedForJsonEncodedResponse() throws IOException { assertEquals("ok", ((Map) capturedResponseBody).get("result")); assertTrue(responseHeaderDoneCalled); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -337,8 +332,7 @@ void responseCallbacksReceiveCorrectDataFor404Response() throws IOException { assertEquals("text/html", capturedResponseHeaders.get("content-type")); assertEquals("Not Found", capturedResponseBody); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -371,8 +365,7 @@ void responseCallbacksApplyFallbackForLambdaUrlWithNonApiGatewayResponse() throw assertEquals("hello", ((Map) capturedResponseBody).get("result")); assertTrue(responseHeaderDoneCalled); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -388,8 +381,7 @@ void responseCallbacksSkipNonApiGatewayResponseForNonHttpEvent() throws IOExcept assertNull(capturedResponseBody); assertFalse(responseHeaderDoneCalled); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -421,8 +413,7 @@ void responseAndRequestCallbacksAreBothInvoked() throws IOException { assertTrue(capturedResponseBody instanceof Map); assertTrue(appSecEnded); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test @@ -485,8 +476,7 @@ void responseCallbacksFireBeforeRequestEnded() throws IOException { assertTrue(callOrder.indexOf("responseStarted") < callOrder.indexOf("requestEnded")); assertTrue(callOrder.indexOf("responseHeaderDone") < callOrder.indexOf("requestEnded")); assertTrue(callOrder.indexOf("responseBody") < callOrder.indexOf("requestEnded")); - assertTraces( - trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); } @Test diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 1c0b1112df2..4183ed11328 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -171,7 +171,8 @@ public static void processResponseData(AgentSpan span, Object result) { Collections.singletonMap("content-type", fallbackContentType); responseData = new LambdaResponseData(0, fallbackHeaders, fallbackBody); } - // else: responseData has explicit headers/body fields β€” keep them, just skip responseStarted + // else: responseData has explicit headers/body fields β€” keep them, just skip + // responseStarted // (statusCode remains 0, so the responseStarted guard below will not fire). } From a440114a22bac02b34798dd71dffe7b8e821d6fe Mon Sep 17 00:00:00 2001 From: Clara Poncet Date: Mon, 29 Jun 2026 10:56:13 +0200 Subject: [PATCH 14/15] Normalise content-type casing before JSON detection --- .../java/datadog/trace/lambda/LambdaAppSecHandler.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 4183ed11328..4454baca348 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -285,9 +285,11 @@ static LambdaResponseData extractResponseData(String json) { String contentType = headers.get("content-type"); // If JSON content-type or unknown, attempt JSON parsing - if (contentType == null - || contentType.contains("json") - || contentType.contains("javascript")) { + // Normalise casing: media type tokens are case-insensitive per RFC 7231 + String contentTypeLower = contentType == null ? null : contentType.toLowerCase(Locale.ROOT); + if (contentTypeLower == null + || contentTypeLower.contains("json") + || contentTypeLower.contains("javascript")) { Object parsed = parseBodyAsJson(bodyString); body = parsed != null ? parsed : bodyString; } else { From 868f9a139e8876376ecf3901ba4e94599bd8250a Mon Sep 17 00:00:00 2001 From: Clara Poncet Date: Mon, 29 Jun 2026 11:03:20 +0200 Subject: [PATCH 15/15] spotlessApply --- .../main/java/datadog/trace/lambda/LambdaAppSecHandler.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 4454baca348..5198b63e294 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -286,7 +286,8 @@ static LambdaResponseData extractResponseData(String json) { // If JSON content-type or unknown, attempt JSON parsing // Normalise casing: media type tokens are case-insensitive per RFC 7231 - String contentTypeLower = contentType == null ? null : contentType.toLowerCase(Locale.ROOT); + String contentTypeLower = + contentType == null ? null : contentType.toLowerCase(Locale.ROOT); if (contentTypeLower == null || contentTypeLower.contains("json") || contentTypeLower.contains("javascript")) {