diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index 610b24b2f31..6f391c63e80 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -126,7 +126,7 @@ static void exit( } String lambdaRequestId = awsContext.getAwsRequestId(); - AgentTracer.get().notifyAppSecEnd(span); + AgentTracer.get().notifyAppSecEnd(span, result); // Force the resource name back to the literal placeholder marker right // before finish so that the Datadog Lambda Extension's filter // (filter_span_from_lambda_library_or_runtime in diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy deleted file mode 100644 index ed1152ea1aa..00000000000 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy +++ /dev/null @@ -1,294 +0,0 @@ -import static datadog.trace.api.gateway.Events.EVENTS - -import datadog.trace.agent.test.naming.VersionedNamingTestBase -import datadog.trace.api.DDSpanTypes -import datadog.trace.api.function.TriConsumer -import datadog.trace.api.function.TriFunction -import datadog.trace.api.gateway.Flow -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.gateway.RequestContextSlot -import datadog.trace.bootstrap.ActiveSubsystems -import datadog.trace.bootstrap.instrumentation.api.AgentTracer -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter -import com.amazonaws.services.lambda.runtime.Context -import java.nio.charset.StandardCharsets -import java.util.function.BiFunction -import java.util.function.Function -import java.util.function.Supplier - -abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase { - def requestId = "test-request-id" - - // Must set this env var before the Datadog integration is initialized. - // If present at load time, the integration auto-enables. - static { - environmentVariables.set("_HANDLER", "Handler") - } - - @Override - String service() { - null - } - - def ig - def appSecStarted = false - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedBody = null - def appSecEnded = false - - def setup() { - ig = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC) - ActiveSubsystems.APPSEC_ACTIVE = true - appSecStarted = false - capturedMethod = null - capturedPath = null - capturedHeaders = [:] - capturedBody = null - appSecEnded = false - ig.registerCallback(EVENTS.requestStarted(), { - appSecStarted = true - new Flow.ResultFlow(new Object()) - } as Supplier) - ig.registerCallback(EVENTS.requestMethodUriRaw(), { RequestContext ctx, String method, URIDataAdapter uri -> - capturedMethod = method - capturedPath = uri.path() - Flow.ResultFlow.empty() - } as TriFunction) - ig.registerCallback(EVENTS.requestHeader(), { RequestContext ctx, String name, String value -> - capturedHeaders[name] = value - } as TriConsumer) - ig.registerCallback(EVENTS.requestHeaderDone(), { RequestContext ctx -> - Flow.ResultFlow.empty() - } as Function) - ig.registerCallback(EVENTS.requestBodyProcessed(), { RequestContext ctx, Object body -> - capturedBody = body - Flow.ResultFlow.empty() - } as BiFunction) - ig.registerCallback(EVENTS.requestEnded(), { RequestContext ctx, Object spanInfo -> - appSecEnded = true - Flow.ResultFlow.empty() - } as BiFunction) - } - - def cleanup() { - ig.reset() - ActiveSubsystems.APPSEC_ACTIVE = false - } - - def "test lambda streaming handler"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "serverless invocation span resource reset after simulated HTTP framework overwrite"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreamingSimulatesHttpFrameworkResource().handleRequest(input, output, ctx) - - then: - assertTraces(1) { - trace(1) { - span { - operationName operation() - resourceName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "test streaming handler with error"() { - when: - def input = new ByteArrayInputStream(StandardCharsets.UTF_8.encode("Hello").array()) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { - getAwsRequestId() >> requestId - } - new HandlerStreamingWithError().handleRequest(input, output, ctx) - - then: - thrown(Error) - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored true - tags { - tag "request_id", requestId - tag "error.type", "java.lang.Error" - tag "error.message", "Some error" - tag "error.stack", String - tag "language", "jvm" - tag "process_id", Long - tag "runtime-id", String - tag "thread.id", Long - tag "thread.name", String - tag "_dd.profiling.ctx", "test" - tag "_dd.profiling.enabled", 0 - tag "_dd.agent_psr", 1.0 - tag "_dd.tracer_host", String - tag "_sample_rate", 1 - tag "_dd.trace_span_attribute_schema", { it != null } - } - } - } - } - } - - def "appsec callbacks are invoked for API Gateway v1 event"() { - given: - def eventJson = """{ - "path": "/api/users/123", - "headers": {"content-type": "application/json", "x-forwarded-for": "203.0.113.1"}, - "body": "{\\"key\\": \\"value\\"}", - "requestContext": { - "httpMethod": "GET", - "requestId": "req-abc", - "identity": {"sourceIp": "203.0.113.1"} - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - appSecStarted - capturedMethod == "GET" - capturedPath == "/api/users/123" - capturedHeaders["content-type"] == "application/json" - capturedBody instanceof Map - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "appsec callbacks are invoked for API Gateway v2 HTTP event"() { - given: - def eventJson = """{ - "version": "2.0", - "headers": {"content-type": "application/json", "accept": "application/json"}, - "cookies": ["session=abc123"], - "body": "{\\"key\\": \\"value\\"}", - "requestContext": { - "http": { - "method": "POST", - "path": "/api/items", - "sourceIp": "198.51.100.1" - }, - "domainName": "api.example.com" - } - }""" - - when: - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - appSecStarted - capturedMethod == "POST" - capturedPath == "/api/items" - capturedHeaders["content-type"] == "application/json" - capturedHeaders["cookie"] == "session=abc123" - capturedBody instanceof Map - appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } - - def "appsec callbacks are not invoked when appsec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - - when: - def eventJson = """{ - "path": "/api/test", - "requestContext": {"httpMethod": "GET", "requestId": "req-xyz"} - }""" - def input = new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)) - def output = new ByteArrayOutputStream() - def ctx = Stub(Context) { getAwsRequestId() >> requestId } - new HandlerStreaming().handleRequest(input, output, ctx) - - then: - !appSecStarted - capturedMethod == null - !appSecEnded - assertTraces(1) { - trace(1) { - span { - operationName operation() - spanType DDSpanTypes.SERVERLESS - errored false - } - } - } - } -} - - -class LambdaHandlerInstrumentationV0Test extends LambdaHandlerInstrumentationTest { - @Override - int version() { - 0 - } - - @Override - String operation() { - "dd-tracer-serverless-span" - } -} - -class LambdaHandlerInstrumentationV1ForkedTest extends LambdaHandlerInstrumentationTest { - @Override - int version() { - 1 - } - - @Override - String operation() { - "aws.lambda.invoke" - } -} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreaming.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreaming.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreaming.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreaming.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingSimulatesHttpFrameworkResource.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingSimulatesHttpFrameworkResource.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingSimulatesHttpFrameworkResource.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingSimulatesHttpFrameworkResource.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWith404Response.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWith404Response.java new file mode 100644 index 00000000000..540b245e802 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWith404Response.java @@ -0,0 +1,24 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; + +public class HandlerStreamingWith404Response implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + PrintWriter writer = + new PrintWriter( + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))); + writer.write( + "{\"statusCode\": 404, " + + "\"headers\": {\"content-type\": \"text/html\"}, " + + "\"body\": \"Not Found\"}"); + writer.close(); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithApiGwResponse.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithApiGwResponse.java new file mode 100644 index 00000000000..7b91f1f3510 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithApiGwResponse.java @@ -0,0 +1,24 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; + +public class HandlerStreamingWithApiGwResponse implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + PrintWriter writer = + new PrintWriter( + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))); + writer.write( + "{\"statusCode\": 200, " + + "\"headers\": {\"content-type\": \"application/json\", \"x-custom\": \"custom-val\"}, " + + "\"body\": \"{\\\"result\\\": \\\"ok\\\"}\"}"); + writer.close(); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithError.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithError.java similarity index 100% rename from dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/HandlerStreamingWithError.java rename to dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithError.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java new file mode 100644 index 00000000000..a34a2623770 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/HandlerStreamingWithRawJson.java @@ -0,0 +1,15 @@ +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; + +/** Writes valid JSON that is not in API Gateway response format (no statusCode/headers/body). */ +public class HandlerStreamingWithRawJson implements RequestStreamHandler { + @Override + public void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) + throws IOException { + outputStream.write("{\"result\": \"hello\"}".getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java new file mode 100644 index 00000000000..9c55c8d5047 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationTest.java @@ -0,0 +1,566 @@ +import static datadog.trace.agent.test.assertions.Matchers.is; +import static datadog.trace.agent.test.assertions.SpanMatcher.span; +import static datadog.trace.agent.test.assertions.TagsMatcher.defaultTags; +import static datadog.trace.agent.test.assertions.TagsMatcher.error; +import static datadog.trace.agent.test.assertions.TagsMatcher.tag; +import static datadog.trace.agent.test.assertions.TraceMatcher.trace; +import static datadog.trace.api.gateway.Events.EVENTS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazonaws.services.lambda.runtime.ClientContext; +import com.amazonaws.services.lambda.runtime.CognitoIdentity; +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.LambdaLogger; +import datadog.trace.agent.test.AbstractInstrumentationTest; +import datadog.trace.api.DDSpanTypes; +import datadog.trace.api.function.TriConsumer; +import datadog.trace.api.function.TriFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.IGSpanInfo; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.gateway.SubscriptionService; +import datadog.trace.bootstrap.ActiveSubsystems; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; +import datadog.trace.junit.utils.config.WithConfig; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@WithConfig(key = "_HANDLER", value = "Handler", env = true, addPrefix = false) +abstract class LambdaHandlerInstrumentationTest extends AbstractInstrumentationTest { + + static final String REQUEST_ID = "test-request-id"; + + // Object to avoid bootstrap class in field type (TestClassShadowingExtension check) + Object ig; + + boolean appSecStarted; + String capturedMethod; + String capturedPath; + Map capturedHeaders; + Object capturedBody; + boolean appSecEnded; + + Integer capturedResponseStatus; + Map capturedResponseHeaders; + Object capturedResponseBody; + boolean responseHeaderDoneCalled; + + abstract int version(); + + abstract String operation(); + + @BeforeEach + void setUpAppSec() { + SubscriptionService ss = + (SubscriptionService) AgentTracer.get().getSubscriptionService(RequestContextSlot.APPSEC); + ig = ss; + ActiveSubsystems.APPSEC_ACTIVE = true; + + appSecStarted = false; + capturedMethod = null; + capturedPath = null; + capturedHeaders = new HashMap<>(); + capturedBody = null; + appSecEnded = false; + capturedResponseStatus = null; + capturedResponseHeaders = new HashMap<>(); + capturedResponseBody = null; + responseHeaderDoneCalled = false; + + ss.registerCallback( + EVENTS.requestStarted(), + (Supplier>) + () -> { + appSecStarted = true; + return new Flow.ResultFlow<>(new Object()); + }); + ss.registerCallback( + EVENTS.requestMethodUriRaw(), + (TriFunction>) + (ctx2, method2, uri) -> { + capturedMethod = method2; + capturedPath = uri.path(); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestHeader(), + (TriConsumer) + (ctx2, name, value) -> capturedHeaders.put(name, value)); + ss.registerCallback( + EVENTS.requestHeaderDone(), + (Function>) ctx2 -> Flow.ResultFlow.empty()); + ss.registerCallback( + EVENTS.requestBodyProcessed(), + (BiFunction>) + (ctx2, body) -> { + capturedBody = body; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestEnded(), + (BiFunction>) + (ctx2, spanInfo) -> { + appSecEnded = true; + return Flow.ResultFlow.empty(); + }); + + ss.registerCallback( + EVENTS.responseStarted(), + (BiFunction>) + (ctx2, status) -> { + capturedResponseStatus = status; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseHeader(), + (TriConsumer) + (ctx2, name, value) -> capturedResponseHeaders.put(name, value)); + ss.registerCallback( + EVENTS.responseHeaderDone(), + (Function>) + ctx2 -> { + responseHeaderDoneCalled = true; + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseBody(), + (BiFunction>) + (ctx2, body) -> { + capturedResponseBody = body; + return Flow.ResultFlow.empty(); + }); + } + + @AfterEach + void cleanUpAppSec() { + ((SubscriptionService) ig).reset(); + ActiveSubsystems.APPSEC_ACTIVE = false; + } + + private Context newContext() { + return new TestContext(REQUEST_ID); + } + + @Test + void testLambdaStreamingHandler() throws IOException { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void serverlessInvocationSpanResourceResetAfterHttpFrameworkOverwrite() throws IOException { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingSimulatesHttpFrameworkResource().handleRequest(input, output, newContext()); + + assertTraces( + trace( + span() + .resourceName(name -> operation().equals(name.toString())) + .type(DDSpanTypes.SERVERLESS) + .error(false))); + } + + @Test + void testStreamingHandlerWithError() { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + + assertThrows( + Error.class, + () -> new HandlerStreamingWithError().handleRequest(input, output, newContext())); + + assertTraces( + trace( + span() + .type(DDSpanTypes.SERVERLESS) + .error(true) + .tags( + defaultTags(), + tag("request_id", is(REQUEST_ID)), + error(Error.class, "Some error")))); + } + + @Test + void appSecCallbacksAreInvokedForApiGatewayV1Event() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/users/123\"," + + "\"headers\": {\"content-type\": \"application/json\"," + + " \"x-forwarded-for\": \"203.0.113.1\"}," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"requestId\": \"req-abc\"," + + " \"identity\": {\"sourceIp\": \"203.0.113.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("GET", capturedMethod); + assertEquals("/api/users/123", capturedPath); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertTrue(capturedBody instanceof Map); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void appSecCallbacksAreInvokedForApiGatewayV2HttpEvent() throws IOException { + String eventJson = + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"content-type\": \"application/json\"," + + " \"accept\": \"application/json\"}," + + "\"cookies\": [\"session=abc123\"]," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"http\": {" + + " \"method\": \"POST\"," + + " \"path\": \"/api/items\"," + + " \"sourceIp\": \"198.51.100.1\"" + + " }," + + " \"domainName\": \"api.example.com\"" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("POST", capturedMethod); + assertEquals("/api/items", capturedPath); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertEquals("session=abc123", capturedHeaders.get("cookie")); + assertTrue(capturedBody instanceof Map); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void appSecCallbacksAreNotInvokedWhenAppSecIsDisabled() throws IOException { + ActiveSubsystems.APPSEC_ACTIVE = false; + + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-xyz\"}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreaming().handleRequest(input, output, newContext()); + + assertFalse(appSecStarted); + assertNull(capturedMethod); + assertFalse(appSecEnded); + assertNull(capturedResponseStatus); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksAreInvokedForJsonEncodedResponse() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"identity\": {\"sourceIp\": \"127.0.0.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertEquals(200, (int) capturedResponseStatus); + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertEquals("custom-val", capturedResponseHeaders.get("x-custom")); + assertTrue(capturedResponseBody instanceof Map); + assertEquals("ok", ((Map) capturedResponseBody).get("result")); + assertTrue(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksReceiveCorrectDataFor404Response() throws IOException { + String eventJson = + "{" + + "\"path\": \"/missing\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWith404Response().handleRequest(input, output, newContext()); + + assertEquals(404, (int) capturedResponseStatus); + assertEquals("text/html", capturedResponseHeaders.get("content-type")); + assertEquals("Not Found", capturedResponseBody); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksApplyFallbackForLambdaUrlWithNonApiGatewayResponse() throws IOException { + // A Lambda Function URL handler returning plain JSON (no statusCode/headers/body structure) + // should trigger the fallback: no responseStarted (status unknown), content-type: + // application/json, full JSON as body. + String eventJson = + "{" + + "\"version\": \"2.0\"," + + "\"rawPath\": \"/\"," + + "\"headers\": {\"host\": \"example.lambda-url.us-east-1.on.aws\"}," + + "\"requestContext\": {" + + " \"domainName\": \"example.lambda-url.us-east-1.on.aws\"," + + " \"http\": {" + + " \"method\": \"GET\"," + + " \"path\": \"/\"," + + " \"sourceIp\": \"1.2.3.4\"" + + " }" + + "}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithRawJson().handleRequest(input, output, newContext()); + + assertNull(capturedResponseStatus); // no responseStarted for status-less fallback + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertTrue(capturedResponseBody instanceof Map); + assertEquals("hello", ((Map) capturedResponseBody).get("result")); + assertTrue(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksSkipNonApiGatewayResponseForNonHttpEvent() throws IOException { + // When the trigger type cannot be determined (non-JSON or non-HTTP event), + // response callbacks must not fire even if the response is valid JSON. + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithRawJson().handleRequest(input, output, newContext()); + + assertNull(capturedResponseStatus); + assertTrue(capturedResponseHeaders.isEmpty()); + assertNull(capturedResponseBody); + assertFalse(responseHeaderDoneCalled); + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseAndRequestCallbacksAreBothInvoked() throws IOException { + String eventJson = + "{" + + "\"path\": \"/api/users/123\"," + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"body\": \"{\\\"key\\\": \\\"value\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"POST\"," + + " \"requestId\": \"req-order-1\"," + + " \"identity\": {\"sourceIp\": \"10.0.0.1\"}" + + "}" + + "}"; + + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertTrue(appSecStarted); + assertEquals("POST", capturedMethod); + assertEquals("/api/users/123", capturedPath); + assertTrue(capturedBody instanceof Map); + + assertEquals(200, (int) capturedResponseStatus); + assertEquals("application/json", capturedResponseHeaders.get("content-type")); + assertTrue(capturedResponseBody instanceof Map); + + assertTrue(appSecEnded); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksFireBeforeRequestEnded() throws IOException { + List callOrder = new ArrayList<>(); + + // Reset and re-register to capture ordering + SubscriptionService ss = (SubscriptionService) ig; + ss.reset(); + + ss.registerCallback( + EVENTS.requestStarted(), + (Supplier>) () -> new Flow.ResultFlow<>(new Object())); + ss.registerCallback( + EVENTS.responseStarted(), + (BiFunction>) + (ctx2, status) -> { + callOrder.add("responseStarted"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseHeaderDone(), + (Function>) + ctx2 -> { + callOrder.add("responseHeaderDone"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.responseBody(), + (BiFunction>) + (ctx2, body) -> { + callOrder.add("responseBody"); + return Flow.ResultFlow.empty(); + }); + ss.registerCallback( + EVENTS.requestEnded(), + (BiFunction>) + (ctx2, spanInfo) -> { + callOrder.add("requestEnded"); + return Flow.ResultFlow.empty(); + }); + + String eventJson = + "{" + + "\"path\": \"/api/test\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"GET\"," + + " \"requestId\": \"req-order-2\"" + + "}" + + "}"; + ByteArrayInputStream input = + new ByteArrayInputStream(eventJson.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + new HandlerStreamingWithApiGwResponse().handleRequest(input, output, newContext()); + + assertTrue(callOrder.contains("responseStarted")); + assertTrue(callOrder.contains("responseHeaderDone")); + assertTrue(callOrder.contains("responseBody")); + assertTrue(callOrder.contains("requestEnded")); + assertTrue(callOrder.indexOf("responseStarted") < callOrder.indexOf("requestEnded")); + assertTrue(callOrder.indexOf("responseHeaderDone") < callOrder.indexOf("requestEnded")); + assertTrue(callOrder.indexOf("responseBody") < callOrder.indexOf("requestEnded")); + assertTraces(trace(span().type(DDSpanTypes.SERVERLESS).error(false))); + } + + @Test + void responseCallbacksReceiveNoDataWhenHandlerThrows() { + ByteArrayInputStream input = new ByteArrayInputStream("Hello".getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + + assertThrows( + Error.class, + () -> new HandlerStreamingWithError().handleRequest(input, output, newContext())); + + assertNull(capturedResponseStatus, "response status should not be set when handler throws"); + assertNull(capturedResponseBody, "response body should not be set when handler throws"); + assertTraces( + trace( + span() + .type(DDSpanTypes.SERVERLESS) + .error(true) + .tags( + defaultTags(), + tag("request_id", is(REQUEST_ID)), + error(Error.class, "Some error")))); + } + + private static final class TestContext implements Context { + private final String requestId; + + TestContext(String requestId) { + this.requestId = requestId; + } + + @Override + public String getAwsRequestId() { + return requestId; + } + + @Override + public String getLogGroupName() { + return null; + } + + @Override + public String getLogStreamName() { + return null; + } + + @Override + public String getFunctionName() { + return null; + } + + @Override + public String getFunctionVersion() { + return null; + } + + @Override + public String getInvokedFunctionArn() { + return null; + } + + @Override + public CognitoIdentity getIdentity() { + return null; + } + + @Override + public ClientContext getClientContext() { + return null; + } + + @Override + public int getRemainingTimeInMillis() { + return 0; + } + + @Override + public int getMemoryLimitInMB() { + return 0; + } + + @Override + public LambdaLogger getLogger() { + return null; + } + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java new file mode 100644 index 00000000000..2c669643a7e --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV0Test.java @@ -0,0 +1,15 @@ +import datadog.trace.junit.utils.config.WithConfig; + +@WithConfig(key = "trace.span.attribute.schema", value = "v0") +class LambdaHandlerInstrumentationV0Test extends LambdaHandlerInstrumentationTest { + + @Override + int version() { + return 0; + } + + @Override + String operation() { + return "dd-tracer-serverless-span"; + } +} diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java new file mode 100644 index 00000000000..c37027b64e2 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/java/LambdaHandlerInstrumentationV1ForkedTest.java @@ -0,0 +1,15 @@ +import datadog.trace.junit.utils.config.WithConfig; + +@WithConfig(key = "trace.span.attribute.schema", value = "v1") +class LambdaHandlerInstrumentationV1ForkedTest extends LambdaHandlerInstrumentationTest { + + @Override + int version() { + return 1; + } + + @Override + String operation() { + return "aws.lambda.invoke"; + } +} diff --git a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java index 28f9e39c710..247f967cff9 100644 --- a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java +++ b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java @@ -1255,7 +1255,8 @@ public void notifyExtensionEnd( } @Override - public void notifyAppSecEnd(AgentSpan span) { + public void notifyAppSecEnd(AgentSpan span, Object result) { + LambdaAppSecHandler.processResponseData(span, result); LambdaAppSecHandler.processRequestEnd(span); } diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 32bfa7fad6e..5198b63e294 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -8,6 +8,7 @@ import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.CallbackProvider; import datadog.trace.api.gateway.Flow; import datadog.trace.api.gateway.IGSpanInfo; import datadog.trace.api.gateway.RequestContext; @@ -22,16 +23,21 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Base64; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +56,10 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); + // Carries the detected trigger type from processRequestStart to processResponseData within the + // same Lambda invocation. Cleared in processRequestEnd. + private static final ThreadLocal CURRENT_TRIGGER_TYPE = new ThreadLocal<>(); + /** * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes * all relevant AppSec gateway callbacks. @@ -64,6 +74,8 @@ public static AgentSpanContext processRequestStart(Object event) { return null; } + CURRENT_TRIGGER_TYPE.set(LambdaTriggerType.UNKNOWN); + if (!(event instanceof ByteArrayInputStream)) { log.debug( "Event is not a ByteArrayInputStream, type: {}", @@ -76,6 +88,7 @@ public static AgentSpanContext processRequestStart(Object event) { if (eventData == LambdaEventData.EMPTY) { return null; } + CURRENT_TRIGGER_TYPE.set(eventData.triggerType); return processAppSecRequestData(eventData); } catch (Exception e) { log.debug("Failed to process AppSec request data", e); @@ -89,6 +102,8 @@ public static AgentSpanContext processRequestStart(Object event) { * @param span the current span */ public static void processRequestEnd(AgentSpan span) { + CURRENT_TRIGGER_TYPE.remove(); + if (!ActiveSubsystems.APPSEC_ACTIVE || span == null) { return; } @@ -106,6 +121,191 @@ public static void processRequestEnd(AgentSpan span) { } } + /** + * Process response data through WAF before the request context is closed. Extracts status code, + * headers, and body from the Lambda response and fires the corresponding gateway events. + * + * @param span the current span + * @param result the Lambda handler result (expected to be a ByteArrayOutputStream) + */ + public static void processResponseData(AgentSpan span, Object result) { + if (!ActiveSubsystems.APPSEC_ACTIVE + || span == null + || !(result instanceof ByteArrayOutputStream)) { + return; + } + + try { + byte[] bytes = ((ByteArrayOutputStream) result).toByteArray(); + if (bytes.length == 0 || bytes.length > MAX_EVENT_SIZE) { + log.debug( + "Response size {} exceeds limit {} or is empty, skipping response processing", + bytes.length, + MAX_EVENT_SIZE); + return; + } + + String json = new String(bytes, StandardCharsets.UTF_8); + LambdaResponseData responseData = extractResponseData(json); + + if (responseData == null || responseData.statusCode == 0) { + // No statusCode means this is not an API-GW formatted response, or JSON parsing failed. + // Only process for known HTTP trigger types, mirroring Python's asm_start_response. + LambdaTriggerType triggerType = CURRENT_TRIGGER_TYPE.get(); + if (triggerType == null || !triggerType.isHttp()) { + return; + } + if (responseData == null || (responseData.headers.isEmpty() && responseData.body == null)) { + // Parse failed or response has no API-GW structure (plain JSON body). + // Treat the full response as the body, mirroring Python's asm_start_response. + Object fallbackBody; + String fallbackContentType; + try { + fallbackBody = OBJECT_ADAPTER.fromJson(json); + fallbackContentType = "application/json"; + } catch (Exception e) { + fallbackBody = json; + fallbackContentType = "text/plain"; + } + Map fallbackHeaders = + Collections.singletonMap("content-type", fallbackContentType); + responseData = new LambdaResponseData(0, fallbackHeaders, fallbackBody); + } + // else: responseData has explicit headers/body fields — keep them, just skip + // responseStarted + // (statusCode remains 0, so the responseStarted guard below will not fire). + } + + RequestContext requestContext = span.getRequestContext(); + if (requestContext == null) { + log.debug("Span has no RequestContext, skipping response processing"); + return; + } + + AgentTracer.TracerAPI tracer = AgentTracer.get(); + CallbackProvider cbp = tracer.getCallbackProvider(RequestContextSlot.APPSEC); + + // Fire response gateway events. Flow results are intentionally ignored: blocking on response + // is not supported for Lambda because remote config is unavailable in that environment. + + // Fire responseStarted + if (responseData.statusCode > 0) { + BiFunction> responseStartedCb = + cbp.getCallback(EVENTS.responseStarted()); + if (responseStartedCb != null) { + responseStartedCb.apply(requestContext, responseData.statusCode); + } + } + + // Fire responseHeader for each allowed header + if (responseData.headers != null && !responseData.headers.isEmpty()) { + TriConsumer responseHeaderCb = + cbp.getCallback(EVENTS.responseHeader()); + if (responseHeaderCb != null) { + for (Map.Entry header : responseData.headers.entrySet()) { + responseHeaderCb.accept(requestContext, header.getKey(), header.getValue()); + } + } + } + + // Fire responseHeaderDone + Function> responseHeaderDoneCb = + cbp.getCallback(EVENTS.responseHeaderDone()); + if (responseHeaderDoneCb != null) { + responseHeaderDoneCb.apply(requestContext); + } + + // Fire responseBody + if (responseData.body != null) { + BiFunction> responseBodyCb = + cbp.getCallback(EVENTS.responseBody()); + if (responseBodyCb != null) { + responseBodyCb.apply(requestContext, responseData.body); + } + } + } catch (Exception e) { + log.debug("Failed to process AppSec response data", e); + } + } + + static LambdaResponseData extractResponseData(String json) { + try { + Map response = MAP_ADAPTER.fromJson(json); + if (response == null) { + return null; + } + + // Extract status code + int statusCode = 0; + Object statusCodeObj = response.get("statusCode"); + if (statusCodeObj instanceof Number) { + statusCode = ((Number) statusCodeObj).intValue(); + } + + // Extract headers — keys are lowercased to normalise casing across API GW / ALB variants + Map headers = new HashMap<>(); + Map rawHeaders = extractStringMap(response.get("headers")); + for (Map.Entry entry : rawHeaders.entrySet()) { + headers.put(entry.getKey().toLowerCase(Locale.ROOT), entry.getValue()); + } + + // Merge multiValueHeaders if present (API GW v1 / ALB), also lowercasing keys + Object multiValueHeadersObj = response.get("multiValueHeaders"); + if (multiValueHeadersObj instanceof Map) { + Map multiValueHeaders = (Map) multiValueHeadersObj; + for (Map.Entry entry : multiValueHeaders.entrySet()) { + if (entry.getKey() != null && entry.getValue() instanceof List) { + String key = String.valueOf(entry.getKey()).toLowerCase(Locale.ROOT); + List values = (List) entry.getValue(); + String joinedValue = + values.stream().map(String::valueOf).collect(Collectors.joining(", ")); + headers.put(key, joinedValue); + } + } + } + + // Extract body + Object body = null; + Object bodyObj = response.get("body"); + if (bodyObj != null) { + String bodyString = String.valueOf(bodyObj); + + // Handle base64 encoding + Object isBase64EncodedObj = response.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64EncodedObj) || "true".equals(isBase64EncodedObj)) { + try { + bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); + } catch (Exception e) { + log.debug("Failed to decode base64 response body", e); + bodyString = null; + } + } + + if (bodyString != null) { + String contentType = headers.get("content-type"); + + // If JSON content-type or unknown, attempt JSON parsing + // Normalise casing: media type tokens are case-insensitive per RFC 7231 + String contentTypeLower = + contentType == null ? null : contentType.toLowerCase(Locale.ROOT); + if (contentTypeLower == null + || contentTypeLower.contains("json") + || contentTypeLower.contains("javascript")) { + Object parsed = parseBodyAsJson(bodyString); + body = parsed != null ? parsed : bodyString; + } else { + body = bodyString; + } + } + } + + return new LambdaResponseData(statusCode, headers, body); + } catch (Exception e) { + log.debug("Failed to parse response data from JSON", e); + return null; + } + } + /** * Merge AppSec context data into extension context. * @@ -459,7 +659,7 @@ private static LambdaEventData extractAlbData( if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { // Handle multi-value headers (combine multiple values with comma) - headers = new java.util.HashMap<>(); + headers = new HashMap<>(); Object multiValueHeadersObj = event.get("multiValueHeaders"); if (multiValueHeadersObj instanceof Map) { Map rawHeaders = (Map) multiValueHeadersObj; @@ -578,7 +778,7 @@ private static LambdaEventData extractGenericData(Map event) { * values to strings, filtering out null entries. */ private static Map extractStringMap(Object mapObj) { - Map result = new java.util.HashMap<>(); + Map result = new HashMap<>(); if (mapObj instanceof Map) { Map rawMap = (Map) mapObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -614,7 +814,7 @@ private static Map extractPathParameters(Object pathParamsObj) { * Map> format expected by AppSec. */ private static Map> extractQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); + Map> result = new HashMap<>(); if (queryParamsObj instanceof Map) { Map rawMap = (Map) queryParamsObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -634,7 +834,7 @@ private static Map> extractQueryParameters(Object queryPara * List> format directly. */ private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { - Map> result = new java.util.HashMap<>(); + Map> result = new HashMap<>(); if (queryParamsObj instanceof Map) { Map rawMap = (Map) queryParamsObj; for (Map.Entry entry : rawMap.entrySet()) { @@ -642,7 +842,7 @@ private static Map> extractMultiValueQueryParameters(Object String key = String.valueOf(entry.getKey()); if (entry.getValue() instanceof java.util.List) { java.util.List values = (java.util.List) entry.getValue(); - java.util.List stringValues = new java.util.ArrayList<>(); + List stringValues = new ArrayList<>(); for (Object value : values) { if (value != null) { stringValues.add(String.valueOf(value)); @@ -730,8 +930,8 @@ private static Object extractBody(Map event) { String bodyString = String.valueOf(bodyObj); // Check if body is base64 encoded (API Gateway feature) - Boolean isBase64Encoded = (Boolean) event.get("isBase64Encoded"); - if (Boolean.TRUE.equals(isBase64Encoded)) { + Object isBase64EncodedObj = event.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64EncodedObj) || "true".equals(isBase64EncodedObj)) { try { bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); } catch (Exception e) { @@ -765,6 +965,15 @@ private static Object parseBodyAsJson(String body) { } } + /** Sets the current trigger type thread-local. Package-private for use in tests only. */ + static void setCurrentTriggerType(LambdaTriggerType type) { + if (type == null) { + CURRENT_TRIGGER_TYPE.remove(); + } else { + CURRENT_TRIGGER_TYPE.set(type); + } + } + /** * Temporary RequestContext implementation to hold AppSecRequestContext before a span is created. */ @@ -827,7 +1036,11 @@ enum LambdaTriggerType { ALB, // Application Load Balancer ALB_MULTI_VALUE, // ALB with multi-value headers LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger + UNKNOWN; // Unknown or unsupported trigger + + boolean isHttp() { + return this != UNKNOWN; + } } /** Object for Lambda event data needed for AppSec processing */ @@ -876,6 +1089,19 @@ static class LambdaEventData { } } + /** Data extracted from a Lambda response for WAF analysis */ + static class LambdaResponseData { + final int statusCode; + final Map headers; + final Object body; + + LambdaResponseData(int statusCode, Map headers, Object body) { + this.statusCode = statusCode; + this.headers = headers; + this.body = body; + } + } + /** URIDataAdapter implementation for Lambda events. */ private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java index eed9ac52e9b..0e6526aaf2b 100644 --- a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java @@ -6,9 +6,14 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -28,6 +33,7 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; import datadog.trace.core.DDCoreJavaSpecification; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -40,102 +46,90 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -@SuppressWarnings("unchecked") -public class LambdaAppSecHandlerTest extends DDCoreJavaSpecification { +class LambdaAppSecHandlerTest extends DDCoreJavaSpecification { - private static boolean originalAppSecActive; - private static AgentTracer.TracerAPI originalTracer; + static boolean originalAppSecActive; + static AgentTracer.TracerAPI originalTracer; @BeforeAll - static void setupSpec() { + static void saveState() { originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE; originalTracer = AgentTracer.get(); } + @AfterAll + static void restoreAppSecState() { + ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive; + } + @BeforeEach - void setup() { + void enableAppSec() { ActiveSubsystems.APPSEC_ACTIVE = true; } @AfterEach - void cleanup() { - ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive; + void resetTracer() { AgentTracer.forceRegister(originalTracer); + LambdaAppSecHandler.setCurrentTriggerType(null); } // ============================================================================ - // processRequestStart basic tests + // processRequestStart — guard tests // ============================================================================ @Test void processRequestStartReturnsNullWhenAppSecIsDisabled() { ActiveSubsystems.APPSEC_ACTIVE = false; ByteArrayInputStream event = createInputStream("{\"test\": \"data\"}"); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test void processRequestStartReturnsNullForNonByteArrayInputStream() { - AgentSpanContext result = LambdaAppSecHandler.processRequestStart("not a stream"); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart("not a stream")); } @Test void processRequestStartReturnsNullForNullEvent() { - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(null); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(null)); } @Test void processRequestStartReturnsNullForOversizedEvent() { int maxSize = Config.get().getAppSecBodyParsingSizeLimit(); - String largeBody = repeatChar('x', maxSize + 1); - ByteArrayInputStream event = createInputStream(largeBody); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + char[] chars = new char[maxSize + 1]; + java.util.Arrays.fill(chars, 'x'); + ByteArrayInputStream event = createInputStream(new String(chars)); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test void processRequestStartReturnsNullForZeroSizeEvent() { ByteArrayInputStream event = createInputStream(""); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test - void processRequestStartReturnsNullForMalformedJSON() { + void processRequestStartReturnsNullForMalformedJson() { ByteArrayInputStream event = createInputStream("{invalid json"); - - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test - void streamCanBeReadMultipleTimesAfterProcessing() throws Exception { + void streamCanBeReadMultipleTimesAfterProcessing() throws IOException { String jsonData = "{\"test\": \"data\", \"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(jsonData); - LambdaAppSecHandler.processRequestStart(event); event.reset(); byte[] bytes = new byte[event.available()]; event.read(bytes); String content = new String(bytes, StandardCharsets.UTF_8); - assertEquals(jsonData, content); } @@ -279,26 +273,20 @@ void detectsLambdaUrlWhenHttpPresentButNoDomainName() { // ============================================================================ @Test + @SuppressWarnings("unchecked") void extractsApiGatewayV1RestDataCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/api/users/123\",\n" - + " \"httpMethod\": \"POST\",\n" - + " \"headers\": {\n" - + " \"Content-Type\": \"application/json\",\n" - + " \"Authorization\": \"Bearer token123\"\n" - + " },\n" - + " \"pathParameters\": {\n" - + " \"userId\": \"123\"\n" - + " },\n" - + " \"body\": \"{\\\"name\\\": \\\"John\\\"}\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\",\n" - + " \"requestId\": \"req-123\",\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"192.168.1.100\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/api/users/123\"," + + "\"httpMethod\": \"POST\"," + + "\"headers\": {\"Content-Type\": \"application/json\", \"Authorization\": \"Bearer token123\"}," + + "\"pathParameters\": {\"userId\": \"123\"}," + + "\"body\": \"{\\\"name\\\": \\\"John\\\"}\"," + + "\"requestContext\": {" + + " \"httpMethod\": \"POST\"," + + " \"requestId\": \"req-123\"," + + " \"identity\": {\"sourceIp\": \"192.168.1.100\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -306,8 +294,8 @@ void extractsApiGatewayV1RestDataCorrectly() { String[] capturedPath = {null}; Map capturedHeaders = new HashMap<>(); String[] capturedSourceIp = {null}; - Integer[] capturedSourcePort = {null}; - Map[] capturedPathParams = new Map[] {null}; + int[] capturedSourcePort = {-1}; + Map[] capturedPathParams = {null}; Object[] capturedBody = {null}; setupMockCallbacks( @@ -317,7 +305,7 @@ void extractsApiGatewayV1RestDataCorrectly() { capturedMethod[0] = method; capturedPath[0] = uri.path(); }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onHeader(capturedHeaders::put) .onSocketAddress( (ip, port) -> { capturedSourceIp[0] = ip; @@ -330,14 +318,14 @@ void extractsApiGatewayV1RestDataCorrectly() { assertNotNull(result); assertInstanceOf(TagContext.class, result); - assertEquals("POST", capturedMethod[0]); assertEquals("/api/users/123", capturedPath[0]); assertEquals("application/json", capturedHeaders.get("Content-Type")); assertEquals("Bearer token123", capturedHeaders.get("Authorization")); assertEquals("192.168.1.100", capturedSourceIp[0]); - assertEquals(Integer.valueOf(0), capturedSourcePort[0]); - assertEquals("123", ((Map) capturedPathParams[0]).get("userId")); + assertEquals(0, capturedSourcePort[0]); + assertNotNull(capturedPathParams[0]); + assertEquals("123", capturedPathParams[0].get("userId")); assertInstanceOf(Map.class, capturedBody[0]); assertEquals("John", ((Map) capturedBody[0]).get("name")); } @@ -345,26 +333,16 @@ void extractsApiGatewayV1RestDataCorrectly() { @Test void extractsApiGatewayV2HttpDataCorrectly() { String eventJson = - "{\n" - + " \"version\": \"2.0\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-custom-header\": \"custom-value\"\n" - + " },\n" - + " \"cookies\": [\"session=abc123\", \"user=john\"],\n" - + " \"pathParameters\": {\n" - + " \"id\": \"456\"\n" - + " },\n" - + " \"body\": \"test body\",\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"PUT\",\n" - + " \"path\": \"/api/items/456\",\n" - + " \"sourceIp\": \"10.0.0.50\",\n" - + " \"sourcePort\": 54321\n" - + " },\n" - + " \"domainName\": \"api.example.com\"\n" - + " }\n" + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"content-type\": \"application/json\", \"x-custom-header\": \"custom-value\"}," + + "\"cookies\": [\"session=abc123\", \"user=john\"]," + + "\"pathParameters\": {\"id\": \"456\"}," + + "\"body\": \"test body\"," + + "\"requestContext\": {" + + " \"http\": {\"method\": \"PUT\", \"path\": \"/api/items/456\", \"sourceIp\": \"10.0.0.50\", \"sourcePort\": 54321}," + + " \"domainName\": \"api.example.com\"" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -372,8 +350,8 @@ void extractsApiGatewayV2HttpDataCorrectly() { String[] capturedPath = {null}; Map capturedHeaders = new HashMap<>(); String[] capturedSourceIp = {null}; - Integer[] capturedSourcePort = {null}; - Map[] capturedPathParams = new Map[] {null}; + int[] capturedSourcePort = {-1}; + Map[] capturedPathParams = {null}; setupMockCallbacks( new Callbacks() @@ -382,7 +360,7 @@ void extractsApiGatewayV2HttpDataCorrectly() { capturedMethod[0] = method; capturedPath[0] = uri.path(); }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onHeader(capturedHeaders::put) .onSocketAddress( (ip, port) -> { capturedSourceIp[0] = ip; @@ -399,26 +377,21 @@ void extractsApiGatewayV2HttpDataCorrectly() { assertEquals("custom-value", capturedHeaders.get("x-custom-header")); assertEquals("session=abc123; user=john", capturedHeaders.get("cookie")); assertEquals("10.0.0.50", capturedSourceIp[0]); - assertEquals(Integer.valueOf(54321), capturedSourcePort[0]); - assertEquals("456", ((Map) capturedPathParams[0]).get("id")); + assertEquals(54321, capturedSourcePort[0]); + assertNotNull(capturedPathParams[0]); + assertEquals("456", capturedPathParams[0].get("id")); } @Test void extractsLambdaFunctionUrlDataCorrectly() { String eventJson = - "{\n" - + " \"version\": \"2.0\",\n" - + " \"headers\": {\n" - + " \"host\": \"xyz.lambda-url.us-east-1.on.aws\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/function/path\",\n" - + " \"sourceIp\": \"1.2.3.4\"\n" - + " },\n" - + " \"domainName\": \"xyz.lambda-url.us-east-1.on.aws\"\n" - + " }\n" + "{" + + "\"version\": \"2.0\"," + + "\"headers\": {\"host\": \"xyz.lambda-url.us-east-1.on.aws\"}," + + "\"requestContext\": {" + + " \"http\": {\"method\": \"GET\", \"path\": \"/function/path\", \"sourceIp\": \"1.2.3.4\"}," + + " \"domainName\": \"xyz.lambda-url.us-east-1.on.aws\"" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -443,18 +416,13 @@ void extractsLambdaFunctionUrlDataCorrectly() { @Test void extractsAlbDataCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/alb/test\",\n" - + " \"httpMethod\": \"DELETE\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-for\": \"203.0.113.42\",\n" - + " \"user-agent\": \"curl/7.64.1\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/alb/test\"," + + "\"httpMethod\": \"DELETE\"," + + "\"headers\": {\"x-forwarded-for\": \"203.0.113.42\", \"user-agent\": \"curl/7.64.1\"}," + + "\"requestContext\": {" + + " \"elb\": {\"targetGroupArn\": \"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/tg/50dc6c495c0c9188\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -482,24 +450,17 @@ void extractsAlbDataCorrectly() { @Test void extractsAlbMultiValueHeadersCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"httpMethod\": \"GET\",\n" - + " \"multiValueHeaders\": {\n" - + " \"accept\": [\"text/html\", \"application/json\"],\n" - + " \"x-custom\": [\"value1\", \"value2\"]\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:...\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/test\"," + + "\"httpMethod\": \"GET\"," + + "\"multiValueHeaders\": {\"accept\": [\"text/html\", \"application/json\"], \"x-custom\": [\"value1\", \"value2\"]}," + + "\"requestContext\": {\"elb\": {\"targetGroupArn\": \"arn:aws:...\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -511,43 +472,34 @@ void extractsAlbMultiValueHeadersCorrectly() { @Test void handlesMultiValueHeadersWithEmptyList() { String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"httpMethod\": \"GET\",\n" - + " \"multiValueHeaders\": {\n" - + " \"accept\": [],\n" - + " \"x-custom\": [\"value1\"]\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"elb\": {\n" - + " \"targetGroupArn\": \"arn:aws:...\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/test\"," + + "\"httpMethod\": \"GET\"," + + "\"multiValueHeaders\": {\"accept\": [], \"x-custom\": [\"value1\"]}," + + "\"requestContext\": {\"elb\": {\"targetGroupArn\": \"arn:aws:...\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedHeaders.get("accept")); // Empty list should result in empty string + assertEquals("", capturedHeaders.get("accept")); assertEquals("value1", capturedHeaders.get("x-custom")); } @Test void extractsWebSocketDataCorrectly() { String eventJson = - "{\n" - + " \"requestContext\": {\n" - + " \"routeKey\": \"$connect\",\n" - + " \"connectionId\": \"conn-abc123\",\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"192.168.0.100\"\n" - + " }\n" - + " }\n" + "{" + + "\"requestContext\": {" + + " \"routeKey\": \"$connect\"," + + " \"connectionId\": \"conn-abc123\"," + + " \"identity\": {\"sourceIp\": \"192.168.0.100\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -577,14 +529,12 @@ void handlesBase64EncodedBodyCorrectly() { String originalBody = "This is test data"; String base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()); String eventJson = - "{\n" - + " \"body\": \"" + "{" + + "\"body\": \"" + base64Body - + "\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + + "\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -610,7 +560,7 @@ void handlesNullBodyCorrectly() { AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("NOT_CALLED", capturedBody[0]); // Callback should not be invoked for null body + assertEquals("NOT_CALLED", capturedBody[0]); } @Test @@ -625,17 +575,15 @@ void handlesEmptyBodyCorrectly() { AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedBody[0]); // Empty body is passed as empty string to WAF + assertEquals("", capturedBody[0]); } @Test void handlesPathWithQueryStringCorrectly() { String eventJson = - "{\n" - + " \"path\": \"/api/users?id=123&filter=active\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\"\n" - + " }\n" + "{" + + "\"path\": \"/api/users?id=123&filter=active\"," + + "\"requestContext\": {\"httpMethod\": \"GET\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -660,21 +608,15 @@ void handlesPathWithQueryStringCorrectly() { @Test void extractsSchemeAndPortFromXForwardedHeaders() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-proto\": \"http\",\n" - + " \"x-forwarded-port\": \"8080\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"x-forwarded-proto\": \"http\", \"x-forwarded-port\": \"8080\"}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( new Callbacks() @@ -688,24 +630,21 @@ void extractsSchemeAndPortFromXForwardedHeaders() { assertNotNull(result); assertEquals("http", capturedScheme[0]); - assertEquals(Integer.valueOf(8080), capturedPort[0]); + assertEquals(8080, capturedPort[0]); } @Test void fallsBackToHttps443WhenXForwardedHeadersAreAbsent() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {},\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( new Callbacks() @@ -719,27 +658,21 @@ void fallsBackToHttps443WhenXForwardedHeadersAreAbsent() { assertNotNull(result); assertEquals("https", capturedScheme[0]); - assertEquals(Integer.valueOf(443), capturedPort[0]); + assertEquals(443, capturedPort[0]); } @Test void handlesInvalidXForwardedPortGracefully() { String eventJson = - "{\n" - + " \"path\": \"/api/test\",\n" - + " \"headers\": {\n" - + " \"x-forwarded-proto\": \"https\",\n" - + " \"x-forwarded-port\": \"not-a-number\"\n" - + " },\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\",\n" - + " \"requestId\": \"req-123\"\n" - + " }\n" + "{" + + "\"path\": \"/api/test\"," + + "\"headers\": {\"x-forwarded-proto\": \"https\", \"x-forwarded-port\": \"not-a-number\"}," + + "\"requestContext\": {\"httpMethod\": \"GET\", \"requestId\": \"req-123\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); String[] capturedScheme = {null}; - Integer[] capturedPort = {null}; + int[] capturedPort = {-1}; setupMockCallbacks( new Callbacks() @@ -753,18 +686,16 @@ void handlesInvalidXForwardedPortGracefully() { assertNotNull(result); assertEquals("https", capturedScheme[0]); - assertEquals(Integer.valueOf(443), capturedPort[0]); + assertEquals(443, capturedPort[0]); } @Test void handlesInvalidBase64BodyGracefully() { String eventJson = - "{\n" - + " \"body\": \"not-valid-base64\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + "{" + + "\"body\": \"not-valid-base64\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -775,42 +706,39 @@ void handlesInvalidBase64BodyGracefully() { AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("NOT_CALLED", capturedBody[0]); // Should not call body callback when decode fails + assertEquals("NOT_CALLED", capturedBody[0]); } @Test void handlesBase64DecodedEmptyStringBody() { String base64Empty = Base64.getEncoder().encodeToString("".getBytes()); String eventJson = - "{\n" - + " \"body\": \"" + "{" + + "\"body\": \"" + base64Empty - + "\",\n" - + " \"isBase64Encoded\": true,\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + + "\"," + + "\"isBase64Encoded\": true," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); - Object[] capturedBody = {"NOT_CALLED"}; + Object[] capturedBody = {null}; setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertEquals("", capturedBody[0]); // Should pass empty string after decoding + assertEquals("", capturedBody[0]); } @Test + @SuppressWarnings("unchecked") void handlesBodyWithSpecialCharacters() { String eventJson = - "{\n" - + " \"body\": \"{\\\"text\\\": \\\"Hello \\u4e16\\u754c \\uD83C\\uDF0D\\\"}\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"POST\"\n" - + " }\n" + "{" + + "\"body\": \"{\\\"text\\\": \\\"Hello \\u4e16\\u754c \\uD83C\\uDF0D\\\"}\"," + + "\"requestContext\": {\"httpMethod\": \"POST\"}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -832,18 +760,12 @@ void handlesBodyWithSpecialCharacters() { @Test void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { String eventJson = - "{\n" - + " \"path\": \"/generic/path\",\n" - + " \"httpMethod\": \"PATCH\",\n" - + " \"headers\": {\n" - + " \"x-custom-header\": \"generic-value\"\n" - + " },\n" - + " \"unknownField\": \"should be ignored\",\n" - + " \"requestContext\": {\n" - + " \"identity\": {\n" - + " \"sourceIp\": \"203.0.113.1\"\n" - + " }\n" - + " }\n" + "{" + + "\"path\": \"/generic/path\"," + + "\"httpMethod\": \"PATCH\"," + + "\"headers\": {\"x-custom-header\": \"generic-value\"}," + + "\"unknownField\": \"should be ignored\"," + + "\"requestContext\": {\"identity\": {\"sourceIp\": \"203.0.113.1\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -859,7 +781,7 @@ void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { capturedMethod[0] = method; capturedPath[0] = uri.path(); }) - .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onHeader(capturedHeaders::put) .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -874,14 +796,10 @@ void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { @Test void extractsDataFromUnknownTriggerWithHttpInRequestContext() { String eventJson = - "{\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"OPTIONS\",\n" - + " \"path\": \"/options/path\",\n" - + " \"sourceIp\": \"198.51.100.50\"\n" - + " }\n" - + " }\n" + "{" + + "\"requestContext\": {" + + " \"http\": {\"method\": \"OPTIONS\", \"path\": \"/options/path\", \"sourceIp\": \"198.51.100.50\"}" + + "}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -909,23 +827,16 @@ void extractsDataFromUnknownTriggerWithHttpInRequestContext() { @Test void handlesCookiesMergingWithExistingCookieHeader() { String eventJson = - "{\n" - + " \"headers\": {\n" - + " \"cookie\": \"existing=value\"\n" - + " },\n" - + " \"cookies\": [\"new=cookie1\", \"another=cookie2\"],\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/\"\n" - + " }\n" - + " }\n" + "{" + + "\"headers\": {\"cookie\": \"existing=value\"}," + + "\"cookies\": [\"new=cookie1\", \"another=cookie2\"]," + + "\"requestContext\": {\"http\": {\"method\": \"GET\", \"path\": \"/\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); @@ -936,28 +847,21 @@ void handlesCookiesMergingWithExistingCookieHeader() { @Test void handlesEmptyCookiesArrayCorrectly() { String eventJson = - "{\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"cookies\": [],\n" - + " \"requestContext\": {\n" - + " \"http\": {\n" - + " \"method\": \"GET\",\n" - + " \"path\": \"/\"\n" - + " }\n" - + " }\n" + "{" + + "\"headers\": {\"content-type\": \"application/json\"}," + + "\"cookies\": []," + + "\"requestContext\": {\"http\": {\"method\": \"GET\", \"path\": \"/\"}}" + "}"; ByteArrayInputStream event = createInputStream(eventJson); Map capturedHeaders = new HashMap<>(); - setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + setupMockCallbacks(new Callbacks().onHeader(capturedHeaders::put)); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); assertNotNull(result); - assertFalse(capturedHeaders.containsKey("cookie")); // Empty array should not add cookie header + assertFalse(capturedHeaders.containsKey("cookie")); } // ============================================================================ @@ -966,8 +870,8 @@ void handlesEmptyCookiesArrayCorrectly() { @Test void processRequestEndDoesNothingWhenSpanIsNull() { - // No exception should be thrown LambdaAppSecHandler.processRequestEnd(null); + // no exception expected } @Test @@ -984,12 +888,12 @@ void processRequestEndDoesNothingWhenAppSecIsDisabled() { void processRequestEndDoesNothingWhenSpanHasNoRequestContext() { AgentSpan span = mock(AgentSpan.class); when(span.getRequestContext()).thenReturn(null); - - // No exception should be thrown LambdaAppSecHandler.processRequestEnd(span); + // no exception expected } @Test + @SuppressWarnings("unchecked") void processRequestEndInvokesRequestEndedCallbackWithRequestContext() { Object mockAppSecContext = new Object(); RequestContext mockRequestContext = mock(RequestContext.class); @@ -1001,31 +905,29 @@ void processRequestEndInvokesRequestEndedCallbackWithRequestContext() { RequestContext[] capturedContext = {null}; AgentSpan[] capturedSpan = {null}; - BiFunction> mockRequestEndedCallback = + BiFunction> requestEndedCallback = mock(BiFunction.class); doAnswer( inv -> { + callbackInvoked[0] = true; capturedContext[0] = inv.getArgument(0); capturedSpan[0] = inv.getArgument(1); - callbackInvoked[0] = true; return new Flow.ResultFlow<>(null); }) - .when(mockRequestEndedCallback) - .apply(any(), any()); + .when(requestEndedCallback) + .apply(any(RequestContext.class), any()); CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); - when(mockCallbackProvider.getCallback(EVENTS.requestEnded())) - .thenReturn(mockRequestEndedCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestEnded())).thenReturn(requestEndedCallback); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); LambdaAppSecHandler.processRequestEnd(span); - assertEquals(true, callbackInvoked[0]); + assertTrue(callbackInvoked[0]); assertEquals(mockRequestContext, capturedContext[0]); assertEquals(span, capturedSpan[0]); } @@ -1042,11 +944,10 @@ void processRequestEndHandlesNullRequestEndedCallbackGracefully() { AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); - // No exception should be thrown - should log warning but not throw LambdaAppSecHandler.processRequestEnd(span); + // no exception expected } // ============================================================================ @@ -1055,37 +956,26 @@ void processRequestEndHandlesNullRequestEndedCallbackGracefully() { @Test void mergeContextsReturnsNullWhenBothContextsAreNull() { - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, null); - - assertNull(result); + assertNull(LambdaAppSecHandler.mergeContexts(null, null)); } @Test void mergeContextsReturnsExtensionContextWhenAppSecContextIsNull() { TagContext extensionContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, null); - - assertEquals(extensionContext, result); + assertEquals(extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, null)); } @Test void mergeContextsReturnsAppSecContextWhenExtensionContextIsNull() { TagContext appSecContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, appSecContext); - - assertEquals(appSecContext, result); + assertEquals(appSecContext, LambdaAppSecHandler.mergeContexts(null, appSecContext)); } @Test void mergeContextsMergesAppSecDataIntoTagContext() { Object appSecData = new Object(); - - // Create real TagContext instances since methods are final TagContext appSecContext = new TagContext(); appSecContext.withRequestContextDataAppSec(appSecData); - TagContext extensionContext = new TagContext(); AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); @@ -1098,20 +988,16 @@ void mergeContextsMergesAppSecDataIntoTagContext() { void mergeContextsReturnsExtensionContextWhenAppSecContextIsNotTagContext() { TagContext extensionContext = mock(TagContext.class); AgentSpanContext appSecContext = mock(AgentSpanContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); - - assertEquals(extensionContext, result); + assertEquals( + extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext)); } @Test void mergeContextsReturnsExtensionContextWhenItIsNotTagContext() { AgentSpanContext extensionContext = mock(AgentSpanContext.class); TagContext appSecContext = mock(TagContext.class); - - AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); - - assertEquals(extensionContext, result); + assertEquals( + extensionContext, LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext)); } // ============================================================================ @@ -1119,6 +1005,7 @@ void mergeContextsReturnsExtensionContextWhenItIsNotTagContext() { // ============================================================================ @Test + @SuppressWarnings("unchecked") void processRequestStartHandlesNullRequestStartedCallbackGracefully() { String eventJson = "{\"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(eventJson); @@ -1129,80 +1016,611 @@ void processRequestStartHandlesNullRequestStartedCallbackGracefully() { AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - - assertNull(result); // Should return null when requestStarted callback is missing + assertNull(LambdaAppSecHandler.processRequestStart(event)); } @Test + @SuppressWarnings("unchecked") void processRequestStartHandlesNullMethodUriCallbackGracefully() { - String eventJson = - "{\n" - + " \"path\": \"/test\",\n" - + " \"requestContext\": {\n" - + " \"httpMethod\": \"GET\"\n" - + " }\n" - + "}"; + String eventJson = "{\"path\": \"/test\", \"requestContext\": {\"httpMethod\": \"GET\"}}"; ByteArrayInputStream event = createInputStream(eventJson); Object mockAppSecContext = new Object(); - - Supplier> mockRequestStartedCallback = mock(Supplier.class); - when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); - - Function> mockHeaderDoneCallback = mock(Function.class); - when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + Supplier> requestStartedCallback = mock(Supplier.class); + when(requestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) - .thenReturn(mockRequestStartedCallback); + .thenReturn(requestStartedCallback); when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())).thenReturn(null); + Function> headerDoneCallback = mock(Function.class); + when(headerDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) - .thenReturn(mockHeaderDoneCallback); + .thenReturn(headerDoneCallback); when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())).thenReturn(null); when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())).thenReturn(null); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); - assertNotNull(result); // Should continue processing even if methodUri callback is null + assertNotNull(result); assertInstanceOf(TagContext.class, result); } @Test void processRequestStartHandlesExceptionDuringJsonParsing() { - String invalidJson = "{this is not valid JSON at all"; - ByteArrayInputStream event = createInputStream(invalidJson); + ByteArrayInputStream event = createInputStream("{this is not valid JSON at all"); + assertNull(LambdaAppSecHandler.processRequestStart(event)); + } - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + @Test + void processRequestStartHandlesExceptionDuringStreamReading() { + ByteArrayInputStream mockStream = + new ByteArrayInputStream("data".getBytes()) { + @Override + public synchronized int available() { + throw new RuntimeException("Stream error"); + } + }; + assertNull(LambdaAppSecHandler.processRequestStart(mockStream)); + } + + // ============================================================================ + // processResponseData Tests — guard conditions + // ============================================================================ + + @Test + void processResponseDataDoesNothingWhenAppSecIsDisabled() { + ActiveSubsystems.APPSEC_ACTIVE = false; + AgentSpan span = mock(AgentSpan.class); + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": \"ok\"}"); + LambdaAppSecHandler.processResponseData(span, result); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingForNullSpan() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + LambdaAppSecHandler.processResponseData(null, result); + // no exception expected + } + + @Test + void processResponseDataDoesNothingForNonByteArrayOutputStreamResult() { + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, "string result"); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingForNullResult() { + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, null); + verify(span, never()).getRequestContext(); + } + + @Test + void processResponseDataDoesNothingWhenSpanHasNoRequestContext() { + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(null); + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + setupMockResponseCallbacks(null, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected + } + + @Test + void processResponseDataDoesNothingForOversizedResponse() { + int maxSize = Config.get().getAppSecBodyParsingSizeLimit(); + char[] chars = new char[maxSize + 1]; + java.util.Arrays.fill(chars, 'x'); + ByteArrayOutputStream result = createOutputStream(new String(chars)); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataDoesNothingForEmptyByteArrayOutputStream() { + ByteArrayOutputStream result = new ByteArrayOutputStream(); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Trigger type gating and fallback --- + + @Test + void processResponseDataSkipsNonApiGwResponseWhenTriggerTypeIsUnknown() { + LambdaAppSecHandler.setCurrentTriggerType(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN); + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, null, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertFalse(headerDoneCalled[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataAppliesFallbackForHttpTriggerWithPlainJsonResponse() { + LambdaAppSecHandler.setCurrentTriggerType(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL); + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + capturedHeaders::put, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertTrue(headerDoneCalled[0]); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("hello", ((Map) capturedBody[0]).get("result")); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataKeepsParsedHeadersAndBodyWhenStatusCodeIsZero() { + // A response that has statusCode:0 with explicit headers/body should use the parsed data, + // not discard it in favour of the plain-response fallback. + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST); + ByteArrayOutputStream result = + createOutputStream( + "{\"statusCode\": 0, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"hello\"}"); + Integer[] capturedStatus = {null}; + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + capturedHeaders::put, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); // statusCode 0 — responseStarted not fired + assertEquals("text/plain", capturedHeaders.get("content-type")); // parsed header kept + assertTrue(headerDoneCalled[0]); + assertEquals("hello", capturedBody[0]); // parsed body kept, not the whole envelope + } + + @Test + void processResponseDataAppliesFallbackForHttpTriggerWithNonJsonStringResponse() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST); + // JSON-encoded string (as returned by a RequestHandler) + ByteArrayOutputStream result = createOutputStream("\"Hello World!\""); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + null, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + assertTrue(headerDoneCalled[0]); + assertEquals("Hello World!", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataWebSocketWithStatusCodeFiresResponseStarted() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET); + // $connect handler returning a proper statusCode — should be treated like any API-GW response + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, null, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(200, capturedStatus[0]); + assertTrue(headerDoneCalled[0]); + } + + @Test + void processResponseDataWebSocketWithoutStatusCodeUsesFallbackWithNoStatus() { + LambdaAppSecHandler.setCurrentTriggerType( + LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET); + // Message-route handler returning arbitrary data — no statusCode, fallback path + ByteArrayOutputStream result = createOutputStream("{\"message\": \"hello\"}"); + Integer[] capturedStatus = {null}; + boolean[] headerDoneCalled = {false}; + Object[] capturedBody = {null}; + AgentSpan span = + setupMockResponseCallbacks( + status -> capturedStatus[0] = status, + null, + () -> headerDoneCalled[0] = true, + body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); // no responseStarted for status-less WebSocket messages + assertTrue(headerDoneCalled[0]); + assertInstanceOf(Map.class, capturedBody[0]); + } + + @Test + void processResponseDataSkipsNonApiGwResponseWhenTriggerTypeIsNull() { + // No processRequestStart called — thread-local is null — behaves like unknown + ByteArrayOutputStream result = createOutputStream("{\"result\": \"hello\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Status code extraction --- + + @Test + void processResponseDataExtractsStatusCodeCorrectly() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(200, capturedStatus[0]); + } + + @Test + void processResponseDataExtractsStatusCodeAsIntegerFromDouble() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 404.0, \"body\": \"not found\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(404, capturedStatus[0]); + } + + @Test + void processResponseDataHandlesMissingStatusCode() { + ByteArrayOutputStream result = createOutputStream("{\"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataHandlesNonNumericStatusCode() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": \"bad\", \"body\": \"ok\"}"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + // --- Header extraction --- + + @Test + void processResponseDataForwardsAllResponseHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\", \"x-custom\": \"val\", \"content-length\": \"42\", \"set-cookie\": \"a=1\"}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals(4, capturedHeaders.size()); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertEquals("val", capturedHeaders.get("x-custom")); + assertEquals("42", capturedHeaders.get("content-length")); + assertEquals("a=1", capturedHeaders.get("set-cookie")); + } + + @Test + void processResponseDataLowercasesHeaderKeys() { + String json = + "{\"statusCode\": 200, \"headers\": {\"Content-Type\": \"text/html\", \"CONTENT-LENGTH\": \"10\"}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("10", capturedHeaders.get("content-length")); + } + + @Test + void processResponseDataMergesMultiValueHeadersWithSingleValueHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"content-encoding\": [\"gzip\", \"br\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("gzip, br", capturedHeaders.get("content-encoding")); + } + + @Test + void processResponseDataHandlesEmptyHeaders() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + Map capturedHeaders = new HashMap<>(); + boolean[] headerDoneCalled = {false}; + AgentSpan span = + setupMockResponseCallbacks( + null, capturedHeaders::put, () -> headerDoneCalled[0] = true, null); + LambdaAppSecHandler.processResponseData(span, result); + assertTrue(capturedHeaders.isEmpty()); + assertTrue(headerDoneCalled[0]); + } + + // --- Body extraction --- + + @Test + @SuppressWarnings("unchecked") + void processResponseDataParsesJsonBody() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\"}, \"body\": \"{\\\"key\\\": \\\"value\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("value", ((Map) capturedBody[0]).get("key")); + } + + @Test + void processResponseDataHandlesNonJsonBodyAsRawString() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"plain text\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("plain text", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataHandlesBase64EncodedBody() { + String originalBody = "{\"decoded\": \"content\"}"; + String base64Body = + Base64.getEncoder().encodeToString(originalBody.getBytes(StandardCharsets.UTF_8)); + String json = + "{\"statusCode\": 200, \"body\": \"" + base64Body + "\", \"isBase64Encoded\": true}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("content", ((Map) capturedBody[0]).get("decoded")); + } + + @Test + void processResponseDataHandlesNullBody() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200, \"body\": null}"); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + void processResponseDataHandlesMissingBodyField() { + ByteArrayOutputStream result = createOutputStream("{\"statusCode\": 200}"); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataAttemptsJsonParseWhenNoContentType() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 200, \"body\": \"{\\\"a\\\": 1}\"}"); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals(1.0d, ((Map) capturedBody[0]).get("a")); + } + + @Test + void processResponseDataFallsBackToRawStringWhenJsonParseFails() { + ByteArrayOutputStream result = + createOutputStream("{\"statusCode\": 200, \"body\": \"not json {\"}"); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("not json {", capturedBody[0]); + } + + // --- Event ordering --- + + @Test + void processResponseDataFiresEventsInCorrectOrder() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/json\"}, \"body\": \"{\\\"k\\\": \\\"v\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + java.util.List order = new java.util.ArrayList<>(); + + AgentSpan span = + setupMockResponseCallbacks( + status -> order.add("responseStarted"), + (name, value) -> order.add("responseHeader"), + () -> order.add("responseHeaderDone"), + body -> order.add("responseBody")); + + LambdaAppSecHandler.processResponseData(span, result); + + assertEquals("responseStarted", order.get(0)); + assertTrue(order.stream().filter("responseHeader"::equals).count() >= 1); + int headerDoneIdx = order.indexOf("responseHeaderDone"); + int lastHeaderIdx = order.lastIndexOf("responseHeader"); + assertTrue(headerDoneIdx > lastHeaderIdx); + assertEquals("responseBody", order.get(order.size() - 1)); + } + + @Test + void processResponseDataHandlesInvalidBase64ResponseBodyGracefully() { + String json = + "{\"statusCode\": 200, \"body\": \"not-valid-base64!!!\", \"isBase64Encoded\": true}"; + ByteArrayOutputStream result = createOutputStream(json); + String[] capturedBody = {"NOT_CALLED"}; + AgentSpan span = + setupMockResponseCallbacks( + null, null, null, body -> capturedBody[0] = String.valueOf(body)); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("NOT_CALLED", capturedBody[0]); + } + + @Test + @SuppressWarnings("unchecked") + void processResponseDataParsesBodyAsJsonForJavascriptContentType() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"application/javascript\"}, \"body\": \"{\\\"key\\\": \\\"val\\\"}\"}"; + ByteArrayOutputStream result = createOutputStream(json); + Object[] capturedBody = {null}; + AgentSpan span = setupMockResponseCallbacks(null, null, null, body -> capturedBody[0] = body); + LambdaAppSecHandler.processResponseData(span, result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("val", ((Map) capturedBody[0]).get("key")); + } + + @Test + void processResponseDataSkipsMultiValueHeadersEntryWithNonListValue() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"x-scalar\": \"not-a-list\", \"x-valid\": [\"v1\", \"v2\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("text/html", capturedHeaders.get("content-type")); + assertEquals("v1, v2", capturedHeaders.get("x-valid")); + assertFalse(capturedHeaders.containsKey("x-scalar")); + } + + @Test + void processResponseDataMultiValueHeadersOverrideSingleValueHeaders() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/html\"}, \"multiValueHeaders\": {\"content-type\": [\"application/json\", \"charset=utf-8\"]}}"; + ByteArrayOutputStream result = createOutputStream(json); + Map capturedHeaders = new HashMap<>(); + AgentSpan span = setupMockResponseCallbacks(null, capturedHeaders::put, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertEquals("application/json, charset=utf-8", capturedHeaders.get("content-type")); + } + + // --- Error handling --- + + @Test + void processResponseDataHandlesMalformedJsonResponse() { + ByteArrayOutputStream result = createOutputStream("{not valid json"); + Integer[] capturedStatus = {null}; + AgentSpan span = + setupMockResponseCallbacks(status -> capturedStatus[0] = status, null, null, null); + LambdaAppSecHandler.processResponseData(span, result); + assertNull(capturedStatus[0]); + } + + @Test + void processResponseDataHandlesEmptyStringResponse() { + ByteArrayOutputStream result = createOutputStream(""); + AgentSpan span = mock(AgentSpan.class); + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected + } + + // ============================================================================ + // processResponseData — null individual callback handling + // ============================================================================ + + @Test + void processResponseDataHandlesNullResponseHeaderDoneCallbackGracefully() { + String json = + "{\"statusCode\": 200, \"headers\": {\"content-type\": \"text/plain\"}, \"body\": \"ok\"}"; + ByteArrayOutputStream result = createOutputStream(json); + + RequestContext mockRequestContext = mock(RequestContext.class); + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(mockRequestContext); + + CallbackProvider cbp = mock(CallbackProvider.class); + when(cbp.getCallback(EVENTS.responseStarted())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseHeader())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseHeaderDone())).thenReturn(null); + when(cbp.getCallback(EVENTS.responseBody())).thenReturn(null); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)).thenReturn(cbp); + AgentTracer.forceRegister(mockTracer); - assertNull(result); // Should return null on parse error + LambdaAppSecHandler.processResponseData(span, result); + // no exception expected — all null callbacks must be silently skipped } + // ============================================================================ + // extractResponseData Unit Tests + // ============================================================================ + @Test - void processRequestStartHandlesExceptionDuringStreamReading() throws IOException { - ByteArrayInputStream mockStream = mock(ByteArrayInputStream.class); - when(mockStream.available()).thenThrow(new IOException("Stream error")); + void extractResponseDataReturnsNullForMalformedJson() { + assertNull(LambdaAppSecHandler.extractResponseData("{bad json")); + } - AgentSpanContext result = LambdaAppSecHandler.processRequestStart(mockStream); + @Test + void extractResponseDataReturnsNullForNullJsonParseResult() { + assertNull(LambdaAppSecHandler.extractResponseData("null")); + } - assertNull(result); // Should return null on IO error + @Test + void extractResponseDataReturnsNullForEmptyString() { + assertNull(LambdaAppSecHandler.extractResponseData("")); } // ============================================================================ - // Helper classes and methods + // Helper Methods // ============================================================================ + private static ByteArrayInputStream createInputStream(String json) { + return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); + } + + private static ByteArrayOutputStream createOutputStream(String json) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + baos.write(json.getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + throw new RuntimeException(e); + } + return baos; + } + private static class Callbacks { BiConsumer onMethodUri; BiConsumer onHeader; @@ -1236,125 +1654,184 @@ Callbacks onBody(Consumer cb) { } } - private static Map mapOf(Object... keysAndValues) { - Map map = new LinkedHashMap<>(); - for (int i = 0; i < keysAndValues.length; i += 2) { - map.put((String) keysAndValues[i], keysAndValues[i + 1]); - } - return map; - } - - private ByteArrayInputStream createInputStream(String json) { - return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); - } - + @SuppressWarnings("unchecked") private void setupMockCallbacks(Callbacks callbacks) { Object mockAppSecContext = new Object(); + Supplier> requestStartedCallback = mock(Supplier.class); + when(requestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); - Supplier> mockRequestStartedCallback = mock(Supplier.class); - when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); - - TriFunction> mockMethodUriCallback = null; + TriFunction> methodUriCallback = null; if (callbacks.onMethodUri != null) { - mockMethodUriCallback = mock(TriFunction.class); - BiConsumer methodUriCb = callbacks.onMethodUri; + methodUriCallback = mock(TriFunction.class); + BiConsumer capture = callbacks.onMethodUri; doAnswer( inv -> { - String method = inv.getArgument(1); - URIDataAdapter uri = inv.getArgument(2); - methodUriCb.accept(method, uri); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1), inv.getArgument(2)); + return Flow.ResultFlow.empty(); }) - .when(mockMethodUriCallback) - .apply(any(), any(), any()); + .when(methodUriCallback) + .apply(any(), anyString(), any(URIDataAdapter.class)); } - TriConsumer mockHeaderCallback = null; + TriConsumer headerCallback = null; if (callbacks.onHeader != null) { - mockHeaderCallback = mock(TriConsumer.class); - BiConsumer headerCb = callbacks.onHeader; + headerCallback = mock(TriConsumer.class); + BiConsumer capture = callbacks.onHeader; doAnswer( inv -> { - String name = inv.getArgument(1); - String value = inv.getArgument(2); - headerCb.accept(name, value); + capture.accept(inv.getArgument(1), inv.getArgument(2)); return null; }) - .when(mockHeaderCallback) - .accept(any(), any(), any()); + .when(headerCallback) + .accept(any(), anyString(), anyString()); } - TriFunction> mockSocketAddressCallback = null; + TriFunction> socketAddressCallback = null; if (callbacks.onSocketAddress != null) { - mockSocketAddressCallback = mock(TriFunction.class); - BiConsumer socketCb = callbacks.onSocketAddress; + socketAddressCallback = mock(TriFunction.class); + BiConsumer capture = callbacks.onSocketAddress; doAnswer( inv -> { - String ip = inv.getArgument(1); - Integer port = inv.getArgument(2); - socketCb.accept(ip, port); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1), (Integer) inv.getArgument(2)); + return Flow.ResultFlow.empty(); }) - .when(mockSocketAddressCallback) - .apply(any(), any(), any()); + .when(socketAddressCallback) + .apply(any(), anyString(), anyInt()); } - Function> mockHeaderDoneCallback = mock(Function.class); - when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + Function> headerDoneCallback = mock(Function.class); + when(headerDoneCallback.apply(any())).thenReturn(Flow.ResultFlow.empty()); - BiFunction, Flow> mockPathParamsCallback = null; + BiFunction, Flow> pathParamsCallback = null; if (callbacks.onPathParams != null) { - mockPathParamsCallback = mock(BiFunction.class); - Consumer> pathParamsCb = callbacks.onPathParams; + pathParamsCallback = mock(BiFunction.class); + Consumer> capture = callbacks.onPathParams; doAnswer( inv -> { - Map params = inv.getArgument(1); - pathParamsCb.accept(params); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1)); + return Flow.ResultFlow.empty(); }) - .when(mockPathParamsCallback) - .apply(any(), any()); + .when(pathParamsCallback) + .apply(any(), any(Map.class)); } - BiFunction> mockBodyCallback = null; + BiFunction> bodyCallback = null; if (callbacks.onBody != null) { - mockBodyCallback = mock(BiFunction.class); - Consumer bodyCb = callbacks.onBody; + bodyCallback = mock(BiFunction.class); + Consumer capture = callbacks.onBody; doAnswer( inv -> { - Object body = inv.getArgument(1); - bodyCb.accept(body); - return new Flow.ResultFlow<>(null); + capture.accept(inv.getArgument(1)); + return Flow.ResultFlow.empty(); }) - .when(mockBodyCallback) + .when(bodyCallback) .apply(any(), any()); } CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) - .thenReturn(mockRequestStartedCallback); + .thenReturn(requestStartedCallback); when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())) - .thenReturn(mockMethodUriCallback); - when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(mockHeaderCallback); + .thenReturn(methodUriCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(headerCallback); when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())) - .thenReturn(mockSocketAddressCallback); + .thenReturn(socketAddressCallback); when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) - .thenReturn(mockHeaderDoneCallback); + .thenReturn(headerDoneCallback); when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())) - .thenReturn(mockPathParamsCallback); - when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())) - .thenReturn(mockBodyCallback); + .thenReturn(pathParamsCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())).thenReturn(bodyCallback); AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) .thenReturn(mockCallbackProvider); - AgentTracer.forceRegister(mockTracer); } - private static String repeatChar(char ch, int count) { - char[] chars = new char[count]; - Arrays.fill(chars, ch); - return new String(chars); + private static Map mapOf(Object... keysAndValues) { + Map map = new LinkedHashMap<>(); + for (int i = 0; i < keysAndValues.length; i += 2) { + map.put((String) keysAndValues[i], keysAndValues[i + 1]); + } + return map; + } + + @SuppressWarnings("unchecked") + private AgentSpan setupMockResponseCallbacks( + Consumer onResponseStarted, + BiConsumer onResponseHeader, + Runnable onResponseHeaderDone, + Consumer onResponseBody) { + + RequestContext mockRequestContext = mock(RequestContext.class); + AgentSpan mockSpan = mock(AgentSpan.class); + when(mockSpan.getRequestContext()).thenReturn(mockRequestContext); + + BiFunction> responseStartedCb = null; + if (onResponseStarted != null) { + responseStartedCb = mock(BiFunction.class); + Consumer capture = onResponseStarted; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1)); + return new Flow.ResultFlow<>(null); + }) + .when(responseStartedCb) + .apply(any(RequestContext.class), anyInt()); + } + + TriConsumer responseHeaderCb = null; + if (onResponseHeader != null) { + responseHeaderCb = mock(TriConsumer.class); + BiConsumer capture = onResponseHeader; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1), inv.getArgument(2)); + return null; + }) + .when(responseHeaderCb) + .accept(any(), anyString(), anyString()); + } + + Function> responseHeaderDoneCb = mock(Function.class); + if (onResponseHeaderDone != null) { + Runnable capture = onResponseHeaderDone; + doAnswer( + inv -> { + capture.run(); + return new Flow.ResultFlow<>(null); + }) + .when(responseHeaderDoneCb) + .apply(any(RequestContext.class)); + } else { + when(responseHeaderDoneCb.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + } + + BiFunction> responseBodyCb = null; + if (onResponseBody != null) { + responseBodyCb = mock(BiFunction.class); + Consumer capture = onResponseBody; + doAnswer( + inv -> { + capture.accept(inv.getArgument(1)); + return new Flow.ResultFlow<>(null); + }) + .when(responseBodyCb) + .apply(any(RequestContext.class), any()); + } + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.responseStarted())).thenReturn(responseStartedCb); + when(mockCallbackProvider.getCallback(EVENTS.responseHeader())).thenReturn(responseHeaderCb); + when(mockCallbackProvider.getCallback(EVENTS.responseHeaderDone())) + .thenReturn(responseHeaderDoneCb); + when(mockCallbackProvider.getCallback(EVENTS.responseBody())).thenReturn(responseBodyCb); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + AgentTracer.forceRegister(mockTracer); + + return mockSpan; } } diff --git a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java index 14a5b5d6d21..8ef7b90a838 100644 --- a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java +++ b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java @@ -378,7 +378,7 @@ default AgentSpan blackholeSpan() { void notifyExtensionEnd(AgentSpan span, Object result, boolean isError, String lambdaRequestId); - void notifyAppSecEnd(AgentSpan span); + void notifyAppSecEnd(AgentSpan span, Object result); AgentDataStreamsMonitoring getDataStreamsMonitoring(); @@ -634,7 +634,7 @@ public void notifyExtensionEnd( AgentSpan span, Object result, boolean isError, String lambdaRequestId) {} @Override - public void notifyAppSecEnd(AgentSpan span) {} + public void notifyAppSecEnd(AgentSpan span, Object result) {} @Override public AgentDataStreamsMonitoring getDataStreamsMonitoring() {