From f894bf425c9b5be3800a32ffcad4d977a108620d Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:33:05 -0600 Subject: [PATCH 1/5] workload-id feature - initial commit --- .../com/azure/cosmos/CustomHeadersTests.java | 170 +++++++++ .../rntbd/RntbdWorkloadIdTests.java | 115 ++++++ .../azure/cosmos/rx/WorkloadIdE2ETests.java | 327 ++++++++++++++++++ .../com/azure/cosmos/CosmosAsyncClient.java | 1 + .../com/azure/cosmos/CosmosClientBuilder.java | 29 ++ .../implementation/AsyncDocumentClient.java | 9 +- .../cosmos/implementation/HttpConstants.java | 3 + .../implementation/RxDocumentClientImpl.java | 61 ++++ .../rntbd/RntbdConstants.java | 3 +- .../rntbd/RntbdRequestHeaders.java | 19 + .../models/CosmosBatchRequestOptions.java | 13 +- .../models/CosmosBulkExecutionOptions.java | 12 +- .../CosmosChangeFeedRequestOptions.java | 13 +- .../models/CosmosItemRequestOptions.java | 15 +- .../models/CosmosQueryRequestOptions.java | 16 + 15 files changed, 784 insertions(+), 22 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java new file mode 100644 index 000000000000..3c95c8bd2687 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. + *

+ * These tests verify the public API surface: builder fluent methods, getter behavior, + * null/empty handling, and that setHeader() is publicly accessible on all request options classes. + */ +public class CustomHeadersTests { + + /** + * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders() + * are stored correctly and retrievable via getCustomHeaders(). + */ + @Test(groups = { "unit" }) + public void customHeadersSetOnBuilder() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25"); + } + + /** + * Verifies that passing null to customHeaders() does not throw and that + * getCustomHeaders() returns null, ensuring graceful null handling. + */ + @Test(groups = { "unit" }) + public void customHeadersNullHandledGracefully() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(null); + + assertThat(builder.getCustomHeaders()).isNull(); + } + + /** + * Verifies that passing an empty map to customHeaders() is accepted and + * getCustomHeaders() returns an empty (not null) map. + */ + @Test(groups = { "unit" }) + public void customHeadersEmptyMapHandled() { + Map emptyHeaders = new HashMap<>(); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(emptyHeaders); + + assertThat(builder.getCustomHeaders()).isEmpty(); + } + + /** + * Verifies that multiple custom headers can be set at once on the builder and + * all entries are preserved and retrievable with correct keys and values. + */ + @Test(groups = { "unit" }) + public void multipleCustomHeadersSupported() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "15"); + headers.put("x-ms-custom-header", "value"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).hasSize(2); + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15"); + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value"); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on CRUD operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnItemRequestOptionsIsPublic() { + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "15"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on batch operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBatchRequestOptionsIsPublic() { + CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "20"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on change feed operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnChangeFeedRequestOptionsIsPublic() { + CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setHeader("x-ms-cosmos-workload-id", "25"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on bulk ingestion operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBulkExecutionOptionsIsPublic() { + CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() + .setHeader("x-ms-cosmos-workload-id", "30"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on query operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnQueryRequestOptionsIsPublic() { + CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "35"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined + * with the correct canonical header name "x-ms-cosmos-workload-id" as expected + * by the Cosmos DB service. + */ + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java new file mode 100644 index 000000000000..9ca123e16160 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.implementation.directconnectivity.rntbd; + +import com.azure.cosmos.implementation.HttpConstants; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants. + *

+ * + * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC), + * correct token type (Byte), is not required, and is not in the thin-client ordered header list + * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()). + */ +public class RntbdWorkloadIdTests { + + /** + * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders + * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and + * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping. + */ + @Test(groups = { "unit" }) + public void workloadIdConstantExists() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader + * with the correct wire ID (0x00DC). This ID is used to identify the header in the + * binary RNTBD protocol when communicating in Direct mode. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderExists() { + // Verify WorkloadId enum value exists with correct ID + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader).isNotNull(); + assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC); + } + + /** + * Verifies that the WorkloadId RNTBD header is defined as Byte token type, + * consistent with the ThroughputBucket pattern. The workload ID value (1-50) + * is encoded as a single byte on the wire. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsByteType() { + // Verify WorkloadId is Byte type (same as ThroughputBucket pattern) + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte); + } + + /** + * Verifies that WorkloadId is not a required RNTBD header. The header is optional — + * requests without a workload ID are valid and should not be rejected by the SDK. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsNotRequired() { + // WorkloadId should not be a required header + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.isRequired()).isFalse(); + } + + /** + * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client + * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is + * excluded from this list and will be auto-encoded in the second pass of + * RntbdTokenStream.encode() along with other non-ordered headers. + */ + @Test(groups = { "unit" }) + public void workloadIdNotInThinClientOrderedList() { + // WorkloadId should NOT be in thinClientHeadersInOrderList + // It will be automatically encoded in the second pass of RntbdTokenStream.encode() + assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList) + .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId); + } + + /** + * Verifies that valid workload ID values (1-50) can be parsed from String to int + * and cast to byte without data loss. Note: the SDK itself does not validate the + * range — this test confirms the encoding path works for expected values. + */ + @Test(groups = { "unit" }) + public void workloadIdValidValues() { + // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly + String[] validValues = {"1", "25", "50"}; + for (String value : validValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + assertThat(byteVal).isBetween((byte) 1, (byte) 50); + } + } + + /** + * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause + * exceptions in the SDK's parsing path. The SDK intentionally does not validate + * the range — invalid values are accepted and sent to the service, which silently + * ignores them. + */ + @Test(groups = { "unit" }) + public void workloadIdInvalidValuesAcceptedBySdk() { + // SDK does NOT validate range — service silently ignores invalid values + // These should not throw exceptions in SDK + String[] invalidValues = {"0", "51", "-1", "100"}; + for (String value : invalidValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + // SDK accepts any integer value that fits in a byte + assertThat(byteVal).isNotNull(); + } + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java new file mode 100644 index 000000000000..85b87090bb36 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.TestConfigurations; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosBulkOperations; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosItemResponse; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * End-to-end integration tests for the custom headers / workload-id feature. + *

+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + * These tests create a real database and container, then execute CRUD and query operations + * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level. + *

+ * What is verified: + * 1. CRUD operations succeed with client-level custom headers (workload-id) + * 2. Per-request header overrides work via setHeader() + * 3. Client with no custom headers continues to work (no regression) + * 4. Query operations succeed with workload-id + * 5. Empty headers and multiple headers are handled correctly + *

+ + */ +public class WorkloadIdE2ETests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + public WorkloadIdE2ETests() { + super(new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY)); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + clientWithWorkloadId = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * Smoke test: verifies that a create (POST) operation succeeds when the client + * has a workload-id custom header set at the builder level. Confirms the header + * flows through the request pipeline without causing errors. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithClientLevelWorkloadId() { + // Smoke test: verify create operation succeeds with client-level workload-id header + TestObject doc = TestObject.create(); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a read (GET) operation succeeds with the client-level workload-id + * header and that the correct document is returned. Ensures the header does not + * interfere with normal read semantics. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void readItemWithClientLevelWorkloadId() { + // Verify read operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getItem().getId()).isEqualTo(doc.getId()); + } + + /** + * Verifies that a replace (PUT) operation succeeds with the client-level workload-id + * header. Confirms the header propagates correctly for update operations. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void replaceItemWithClientLevelWorkloadId() { + // Verify replace operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + doc.setStringProp("updated-" + UUID.randomUUID()); + CosmosItemResponse response = container + .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + } + + /** + * Verifies that a delete operation succeeds with the client-level workload-id header + * and returns the expected 204 No Content status code. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void deleteItemWithClientLevelWorkloadId() { + // Verify delete operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(204); + } + + /** + * Verifies that a per-request workload-id header override via + * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header + * (value "30") should take precedence over the client-level default (value "15"). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override works — request-level should take precedence + TestObject doc = TestObject.create(); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30"); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), options) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a cross-partition query operation succeeds when the client has a + * workload-id custom header. Confirms the header flows correctly through the + * query pipeline and does not affect result correctness. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithClientLevelWorkloadId() { + // Verify query operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions(); + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Verifies that a per-request workload-id header override on + * {@code CosmosQueryRequestOptions.setHeader()} works for query operations. + * The request-level header (value "42") should take precedence over the + * client-level default. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override on query options works + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Regression test: verifies that a client created without any custom headers + * continues to work normally. Ensures the custom headers feature does not + * introduce regressions for clients that do not use it. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithNoCustomHeadersStillWorks() { + // Verify that a client without custom headers works normally (no regression) + CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithoutHeaders); + } + } + + /** + * Verifies that a client created with an empty custom headers map works normally. + * An empty map should behave identically to no custom headers — no errors, + * no unexpected behavior. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithEmptyCustomHeaders() { + // Verify that a client with empty custom headers map works normally + CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(new HashMap<>()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithEmptyHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithEmptyHeaders); + } + } + + /** + * Verifies that a client can be configured with multiple custom headers simultaneously + * (workload-id plus an additional custom header). Confirms that all headers flow + * through the pipeline without interfering with each other. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithMultipleCustomHeaders() { + // Verify that multiple custom headers can be set simultaneously + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); + headers.put("x-ms-custom-test-header", "test-value"); + + CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(headers) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithMultipleHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithMultipleHeaders); + } + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java index ec0dd64af008..f54f44482db5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java @@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable { .withDefaultSerializer(this.defaultCustomSerializer) .withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled()) .withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled()) + .withCustomHeaders(builder.getCustomHeaders()) .build(); this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 12d022e69ee7..aea282be566c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -37,6 +37,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; @@ -155,6 +156,7 @@ public class CosmosClientBuilder implements private boolean serverCertValidationDisabled = false; private Function containerFactory = null; + private Map customHeaders; /** * Instantiates a new Cosmos client builder. @@ -734,6 +736,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { return this; } + /** + * Sets custom HTTP headers that will be included with every request from this client. + *

+ * These headers are sent with all requests. For Direct/RNTBD mode, only known headers + * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers + * work only in Gateway mode. + *

+ * If the same header is also set on request options (e.g., + * {@code CosmosItemRequestOptions.setHeader(String, String)}), + * the request-level value takes precedence over the client-level value. + * + * @param customHeaders map of header name to value + * @return current CosmosClientBuilder + */ + public CosmosClientBuilder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + /** + * Gets the custom headers configured on this builder. + * @return the custom headers map, or null if not set + */ + Map getCustomHeaders() { + return this.customHeaders; + } + /** * Sets the retry policy options associated with the DocumentClient instance. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 03590c1f8a5d..7953721019c5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -116,6 +116,7 @@ class Builder { private boolean isRegionScopedSessionCapturingEnabled; private boolean isPerPartitionAutomaticFailoverEnabled; private List operationPolicies; + private Map customHeaders; public Builder withServiceEndpoint(String serviceEndpoint) { try { @@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu return this; } + public Builder withCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + private void ifThrowIllegalArgException(boolean value, String error) { if (value) { throw new IllegalArgumentException(error); @@ -328,7 +334,8 @@ public AsyncDocumentClient build() { defaultCustomSerializer, isRegionScopedSessionCapturingEnabled, operationPolicies, - isPerPartitionAutomaticFailoverEnabled); + isPerPartitionAutomaticFailoverEnabled, + customHeaders); client.init(state, null); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java index 4e283defbc1d..32378ef0cc8d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java @@ -298,6 +298,9 @@ public static class HttpHeaders { // Region affinity headers public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only"; + + // Workload ID header for Azure Monitor metrics attribution + public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id"; } public static class A_IMHeaderValues { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 192f8175978f..2f0bd4271d86 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -293,6 +293,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization private final AtomicReference cachedCosmosAsyncClientSnapshot; private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads; private Consumer perPartitionFailoverConfigModifier; + private Map customHeaders; public RxDocumentClientImpl(URI serviceEndpoint, String masterKeyOrResourceToken, @@ -366,6 +367,60 @@ public RxDocumentClientImpl(URI serviceEndpoint, boolean isRegionScopedSessionCapturingEnabled, List operationPolicies, boolean isPerPartitionAutomaticFailoverEnabled) { + this( + serviceEndpoint, + masterKeyOrResourceToken, + permissionFeed, + connectionPolicy, + consistencyLevel, + readConsistencyStrategy, + configs, + cosmosAuthorizationTokenResolver, + credential, + tokenCredential, + sessionCapturingOverride, + connectionSharingAcrossClientsEnabled, + contentResponseOnWriteEnabled, + metadataCachesSnapshot, + apiType, + clientTelemetryConfig, + clientCorrelationId, + cosmosEndToEndOperationLatencyPolicyConfig, + sessionRetryOptions, + containerProactiveInitConfig, + defaultCustomSerializer, + isRegionScopedSessionCapturingEnabled, + operationPolicies, + isPerPartitionAutomaticFailoverEnabled, + null + ); + } + + public RxDocumentClientImpl(URI serviceEndpoint, + String masterKeyOrResourceToken, + List permissionFeed, + ConnectionPolicy connectionPolicy, + ConsistencyLevel consistencyLevel, + ReadConsistencyStrategy readConsistencyStrategy, + Configs configs, + CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver, + AzureKeyCredential credential, + TokenCredential tokenCredential, + boolean sessionCapturingOverride, + boolean connectionSharingAcrossClientsEnabled, + boolean contentResponseOnWriteEnabled, + CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, + ApiType apiType, + CosmosClientTelemetryConfig clientTelemetryConfig, + String clientCorrelationId, + CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig, + SessionRetryOptions sessionRetryOptions, + CosmosContainerProactiveInitConfig containerProactiveInitConfig, + CosmosItemSerializer defaultCustomSerializer, + boolean isRegionScopedSessionCapturingEnabled, + List operationPolicies, + boolean isPerPartitionAutomaticFailoverEnabled, + Map customHeaders) { this( serviceEndpoint, masterKeyOrResourceToken, @@ -392,6 +447,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver; this.operationPolicies = operationPolicies; + this.customHeaders = customHeaders; } private RxDocumentClientImpl(URI serviceEndpoint, @@ -1884,6 +1940,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) { Map headers = new HashMap<>(); + // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders()) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + headers.putAll(this.customHeaders); + } + if (this.useMultipleWriteLocations) { headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString()); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index ba3ec8d2017d..d79231793679 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -598,7 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader { PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), - HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); + HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false), + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false); public static final List thinClientHeadersInOrderList = Arrays.asList( EffectivePartitionKey, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index 6f6e46ee695d..387e8cf3ed5a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonFilter; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { // region Fields + private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class); private static final String URL_TRIM = "/"; // endregion @@ -134,6 +137,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { this.addGlobalDatabaseAccountName(headers); this.addThroughputBucket(headers); this.addHubRegionProcessingOnly(headers); + this.addWorkloadId(headers); // Normal headers (Strings, Ints, Longs, etc.) @@ -297,6 +301,8 @@ private RntbdToken getCorrelatedActivityId() { private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); } + private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); } + private RntbdToken getGlobalDatabaseAccountName() { return this.get(RntbdRequestHeader.GlobalDatabaseAccountName); } @@ -816,6 +822,19 @@ private void addHubRegionProcessingOnly(final Map headers) { } } + private void addWorkloadId(final Map headers) { + final String value = headers.get(HttpHeaders.WORKLOAD_ID); + + if (StringUtils.isNotEmpty(value)) { + try { + final int workloadId = Integer.valueOf(value); + this.getWorkloadId().setValue((byte) workloadId); + } catch (NumberFormatException e) { + logger.warn("Invalid value for workload id header: {}", value, e); + } + } + } + private void addGlobalDatabaseAccountName(final Map headers) { final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index 7d5a27324f95..3183fe59bdea 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -154,14 +154,17 @@ RequestOptions toRequestOptions() { } /** - * Sets the custom batch request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBatchRequestOptions. */ - CosmosBatchRequestOptions setHeader(String name, String value) { + public CosmosBatchRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index f125c02d6725..cd688f8a0da6 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -257,13 +257,17 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat } /** - * Sets the custom bulk request option value by key + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBulkExecutionOptions. */ - CosmosBulkExecutionOptions setHeader(String name, String value) { + public CosmosBulkExecutionOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index 3ac526de6d63..a1b675f2ffd8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -564,14 +564,17 @@ public List getExcludedRegions() { } /** - * Sets the custom change feed request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosChangeFeedRequestOptions. */ - CosmosChangeFeedRequestOptions setHeader(String name, String value) { + public CosmosChangeFeedRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index 72eb108a6428..fbc540e5baeb 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -566,14 +566,17 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre } /** - * Sets the custom item request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value - * + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosItemRequestOptions. */ - CosmosItemRequestOptions setHeader(String name, String value) { + public CosmosItemRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index 7ead6e208781..f0de81bbf823 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -260,6 +260,22 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) return this; } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosQueryRequestOptions. + */ + public CosmosQueryRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + /** * Gets the list of regions to exclude for the request/retries. These regions are excluded * from the preferred region list. From a75ab7b1bd494dfec12ad501b23923a101b1cba4 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:57:37 -0600 Subject: [PATCH 2/5] workload-id feature - initial commit(Spark) --- .../cosmos/spark/CosmosClientCache.scala | 15 +- .../spark/CosmosClientConfiguration.scala | 8 +- .../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++- .../cosmos/spark/CosmosClientCacheITest.scala | 17 +- .../spark/CosmosClientConfigurationSpec.scala | 68 ++++++++ .../spark/CosmosPartitionPlannerSpec.scala | 24 ++- .../cosmos/spark/PartitionMetadataSpec.scala | 48 ++++-- .../spark/SparkE2EWorkloadIdITest.scala | 150 ++++++++++++++++++ 8 files changed, 323 insertions(+), 38 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index e61a271aeb8b..9ad739d8f3f6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import java.util.function.BiPredicate import scala.collection.concurrent.TrieMap - // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import @@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + // Apply custom HTTP headers (e.g., workload-id) to the builder if configured. + // These headers are attached to every Cosmos DB request made by this client instance. + if (cosmosClientConfiguration.customHeaders.isDefined) { + builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava) + } + var client = builder.buildAsyncClient() if (cosmosClientConfiguration.clientInterceptors.isDefined) { @@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Custom HTTP headers are part of the cache key because different workload-ids + // should produce different CosmosAsyncClient instances + customHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientBuilderInterceptors, clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, - clientConfig.azureMonitorConfig + clientConfig.azureMonitorConfig, + clientConfig.customHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 6f4e26e1f503..61fa0957af83 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration ( clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Optional custom HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.customHeaders() + customHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientBuilderInterceptors, cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, - diagnosticsConfig.azureMonitorConfig + diagnosticsConfig.azureMonitorConfig, + cosmosAccountConfig.customHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index eef3f6ae1f8d..6646b2e69ae5 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter import java.time.{Duration, Instant} import java.util import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable @@ -150,6 +151,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" + // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} + // Flows through to CosmosClientBuilder.customHeaders(). + val CustomHeaders = "spark.cosmos.customHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -295,7 +300,8 @@ private[spark] object CosmosConfigNames { WriteOnRetryCommitInterceptor, WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, - WriteMaxRetryNoProgressIntervalInSeconds + WriteMaxRetryNoProgressIntervalInSeconds, + CustomHeaders ) def validateConfigName(name: String): Unit = { @@ -538,7 +544,10 @@ private case class CosmosAccountConfig(endpoint: String, resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], - clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + // Optional custom HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder + customHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -719,6 +728,19 @@ private object CosmosAccountConfig extends BasicLoggingTrait { parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like + // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. + // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache. + private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.CustomHeaders, + mandatory = false, + parseFromStringFunction = headersJson => { + val mapper = new com.fasterxml.jackson.databind.ObjectMapper() + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + mapper.readValue(headersJson, typeRef).asScala.toMap + }, + helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -753,6 +775,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) + // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -864,7 +888,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { resourceGroupNameOpt, azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, - if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, + customHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index ccf36791dc96..4d542c44612e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -64,7 +64,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -91,7 +92,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -118,7 +120,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -145,7 +148,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ) ) @@ -179,8 +183,9 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None - ) + azureMonitorConfig = None, + customHeaders = None + ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 7fcc601ba016..a0627c0cf3dd 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -408,4 +408,72 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ") configuration.azureMonitorConfig shouldEqual None } + + // Verifies that the spark.cosmos.customHeaders configuration option correctly parses + // a JSON string containing a single workload-id header into a Map[String, String] on + // CosmosClientConfiguration. This is the primary use case for the workload-id feature. + it should "parse customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + } + + // Verifies that when spark.cosmos.customHeaders is not specified in the config map, + // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set customHeaders continue to work without changes. + it should "handle missing customHeaders" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe None + } + + // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing + // multiple headers into a Map with all entries preserved. This supports use cases where + // multiple custom headers need to be sent alongside workload-id. + it should "parse multiple custom headers" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get should have size 2 + configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20" + configuration.customHeaders.get("x-custom-header") shouldEqual "value" + } + + // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, + // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases + // and that no headers are injected when the JSON object is empty. + it should "handle empty customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> "{}" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get shouldBe empty + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 6ef90b55989d..ab73dc4e54d3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index dfd14c36c80f..65274bee2b19 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala new file mode 100644 index 000000000000..d9706d0709e5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.util.UUID + +/** + * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector. + * + * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows + * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that + * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations + * initiated via Spark DataFrames (reads and writes). + * + * Requires the Cosmos DB Emulator running + */ +class SparkE2EWorkloadIdITest + extends IntegrationSpec + with Spark + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait { + + val objectMapper = new ObjectMapper() + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The header should be passed through to the + // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior. + "spark query with customHeaders" can "read items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + + val item = rowsArray(0) + item.getAs[String]("id") shouldEqual id + } + + // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The item is written via Spark and then verified + // via a direct SDK read to confirm the write was persisted correctly. + "spark write with customHeaders" can "write items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testWriteItem" + | } + |""".stripMargin + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.write.strategy" -> "ItemOverwrite", + "spark.cosmos.write.bulk.enabled" -> "false", + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + val spark_session = spark + import spark_session.implicits._ + val df = spark.read.json(Seq(rawItem).toDS()) + + df.write.format("cosmos.oltp").options(cfg).mode("Append").save() + + // Verify the item was written by reading it back via the SDK directly + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block() + readItem.getItem.get("id").textValue() shouldEqual id + readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" + } + + // Regression test: verifies that Spark read operations continue to work correctly when + // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not + // break existing behavior for clients that do not use custom headers. + "spark operations without customHeaders" can "still succeed" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "noHeadersItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + rowsArray(0).getAs[String]("id") shouldEqual id + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals + //scalastyle:on null +} From 1bbf6b9bb962fec1f576fbf3a4b0863a0a132099 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:08:04 -0600 Subject: [PATCH 3/5] workload-id feature - cleaning up comments --- .../com/azure/cosmos/rx/WorkloadIdE2ETests.java | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index 85b87090bb36..a57b4d9d9a0b 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -32,17 +32,6 @@ * End-to-end integration tests for the custom headers / workload-id feature. *

* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. - * These tests create a real database and container, then execute CRUD and query operations - * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level. - *

- * What is verified: - * 1. CRUD operations succeed with client-level custom headers (workload-id) - * 2. Per-request header overrides work via setHeader() - * 3. Client with no custom headers continues to work (no regression) - * 4. Query operations succeed with workload-id - * 5. Empty headers and multiple headers are handled correctly - *

- */ public class WorkloadIdE2ETests extends TestSuiteBase { @@ -82,13 +71,12 @@ public void beforeClass() { } /** - * Smoke test: verifies that a create (POST) operation succeeds when the client + * verifies that a create (POST) operation succeeds when the client * has a workload-id custom header set at the builder level. Confirms the header * flows through the request pipeline without causing errors. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void createItemWithClientLevelWorkloadId() { - // Smoke test: verify create operation succeeds with client-level workload-id header TestObject doc = TestObject.create(); CosmosItemResponse response = container From 7aeb8b4e030958f77e14983fc5f0d3475333effb Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:20:36 -0600 Subject: [PATCH 4/5] workload-id feature - addressing copilot comments --- sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + .../directconnectivity/rntbd/RntbdConstants.java | 4 ++-- .../directconnectivity/rntbd/RntbdRequestHeaders.java | 2 +- 8 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index de77b69485a8..2eeccc69dada 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 80072357c58f..6e72286dfab6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index b905a025a1e6..d0d80af466d6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index e17cecdfdac2..ae444a9c399f 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index e9be63ef89bd..ae910280495b 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 8475b83dc5ee..ea4fdb82e1dc 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -4,6 +4,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index d79231793679..d75bf5dc88e1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -598,8 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader { PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), - HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false), - WorkloadId((short)0x00DC, RntbdTokenType.Byte, false); + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false), + HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); public static final List thinClientHeadersInOrderList = Arrays.asList( EffectivePartitionKey, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index 387e8cf3ed5a..46f8060387fc 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -827,7 +827,7 @@ private void addWorkloadId(final Map headers) { if (StringUtils.isNotEmpty(value)) { try { - final int workloadId = Integer.valueOf(value); + final int workloadId = Integer.parseInt(value); this.getWorkloadId().setValue((byte) workloadId); } catch (NumberFormatException e) { logger.warn("Invalid value for workload id header: {}", value, e); From 6d00d49320caa42432aa26725befe529e5934ad4 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:38:14 -0600 Subject: [PATCH 5/5] workload-id feature - addressing comments --- .../com/azure/cosmos/spark/CosmosConfig.scala | 13 +- .../spark/CosmosClientConfigurationSpec.scala | 13 +- .../com/azure/cosmos/CustomHeadersTests.java | 104 ++++++++++- .../RxDocumentClientUnderTest.java | 7 +- .../RxGatewayStoreModelTest.java | 172 +++++++++++++++++- .../SpyClientUnderTestFactory.java | 7 +- .../GatewayAddressCacheTest.java | 125 +++++++++++++ .../GlobalAddressResolverTest.java | 3 +- .../azure/cosmos/rx/WorkloadIdE2ETests.java | 62 ++----- .../com/azure/cosmos/CosmosClientBuilder.java | 50 ++++- .../implementation/RxDocumentClientImpl.java | 12 +- .../implementation/RxGatewayStoreModel.java | 17 +- .../implementation/ThinClientStoreModel.java | 3 +- .../GatewayAddressCache.java | 17 +- .../GlobalAddressResolver.java | 8 +- .../models/CosmosReadManyRequestOptions.java | 16 ++ 16 files changed, 546 insertions(+), 83 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 6646b2e69ae5..62802d23b14c 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper -import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} +import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils} import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition} import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} @@ -735,9 +735,14 @@ private object CosmosAccountConfig extends BasicLoggingTrait { key = CosmosConfigNames.CustomHeaders, mandatory = false, parseFromStringFunction = headersJson => { - val mapper = new com.fasterxml.jackson.databind.ObjectMapper() - val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} - mapper.readValue(headersJson, typeRef).asScala.toMap + try { + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + } catch { + case e: Exception => throw new IllegalArgumentException( + s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " + + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) + } }, helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index a0627c0cf3dd..377425189f07 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -441,10 +441,11 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.customHeaders shouldBe None } - // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing - // multiple headers into a Map with all entries preserved. This supports use cases where - // multiple custom headers need to be sent alongside workload-id. - it should "parse multiple custom headers" in { + // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level. + // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD), + // unknown headers are silently dropped, so the allowlist ensures consistent behavior + // across Gateway and Direct modes. + it should "reject unknown custom headers" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", @@ -454,10 +455,10 @@ class CosmosClientConfigurationSpec extends UnitSpec { val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is. + // The allowlist validation happens later in CosmosClientBuilder.customHeaders() configuration.customHeaders shouldBe defined configuration.customHeaders.get should have size 2 - configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20" - configuration.customHeaders.get("x-custom-header") shouldEqual "value" } // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java index 3c95c8bd2687..19eb03744d1a 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java @@ -9,6 +9,7 @@ import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; import com.azure.cosmos.models.CosmosItemRequestOptions; import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; import com.azure.cosmos.models.FeedRange; import org.testng.annotations.Test; @@ -16,6 +17,7 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. @@ -73,23 +75,40 @@ public void customHeadersEmptyMapHandled() { } /** - * Verifies that multiple custom headers can be set at once on the builder and - * all entries are preserved and retrievable with correct keys and values. + * Verifies that headers not in the allowlist are rejected with IllegalArgumentException. + * This ensures consistent behavior across Gateway and Direct modes — only headers with + * RNTBD encoding support are allowed. */ @Test(groups = { "unit" }) - public void multipleCustomHeadersSupported() { + public void unknownHeaderRejectedByAllowlist() { Map headers = new HashMap<>(); - headers.put("x-ms-cosmos-workload-id", "15"); headers.put("x-ms-custom-header", "value"); - CosmosClientBuilder builder = new CosmosClientBuilder() + assertThatThrownBy(() -> new CosmosClientBuilder() .endpoint("https://test.documents.azure.com:443/") .key("dGVzdEtleQ==") - .customHeaders(headers); + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header") + .hasMessageContaining("not allowed"); + } + + /** + * Verifies that a map containing both an allowed header and a disallowed header + * is rejected — the entire map must pass the allowlist check. + */ + @Test(groups = { "unit" }) + public void mixedAllowedAndDisallowedHeadersRejected() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "15"); + headers.put("x-ms-custom-header", "value"); - assertThat(builder.getCustomHeaders()).hasSize(2); - assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15"); - assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value"); + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header"); } /** @@ -158,6 +177,19 @@ public void setHeaderOnQueryRequestOptionsIsPublic() { assertThat(options).isNotNull(); } + /** + * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on read-many operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnReadManyRequestOptionsIsPublic() { + CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "40"); + + assertThat(options).isNotNull(); + } + /** * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined * with the correct canonical header name "x-ms-cosmos-workload-id" as expected @@ -167,4 +199,58 @@ public void setHeaderOnQueryRequestOptionsIsPublic() { public void workloadIdHttpHeaderConstant() { assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); } + + /** + * Verifies that a non-numeric workload-id value is rejected at builder level with + * IllegalArgumentException. This covers both Gateway and Direct modes consistently + * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. + * Range validation [1, 50] is the backend's responsibility — the SDK only validates + * that the value is a valid integer. This avoids hardcoding a range the backend team + * might change in the future. + */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + } + + /** + * Verifies that a valid workload-id value passes builder validation. + */ + @Test(groups = { "unit" }) + public void validWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java index a9f5cb35549c..d5f8b92ac7a6 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.mockito.Mockito.doAnswer; @@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy( GlobalEndpointManager globalEndpointManager, GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker, HttpClient rxOrigClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { origHttpClient = rxOrigClient; spyHttpClient = Mockito.spy(rxOrigClient); @@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy( userAgentContainer, globalEndpointManager, spyHttpClient, - apiType); + apiType, + customHeaders); } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java index 54440ecfabc5..587844f4043a 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java @@ -27,6 +27,8 @@ import java.net.SocketException; import java.net.URI; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext; @@ -102,6 +104,7 @@ public void readTimeout() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -205,7 +209,8 @@ public void applySessionToken( new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( @@ -277,7 +282,8 @@ public void validateApiType() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception { return false; } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are injected into + * outgoing HTTP requests by performRequest(). This covers metadata requests + * (collection cache, partition key range) that don't go through getRequestHeaders(). + */ + @Test(groups = "unit") + public void customHeadersInjectedInPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25"); + } + + /** + * Verifies that request-level headers take precedence over client-level customHeaders. + * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest() + * should NOT overwrite it. + */ + @Test(groups = "unit") + public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col/docs/doc1", + ResourceType.Document); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + // Simulate request-level header already set (e.g., by getRequestHeaders()) + dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // Request-level header "42" should win over client-level "10" + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42"); + } + + /** + * Verifies that when customHeaders is null, performRequest() still works normally + * without injecting any extra headers. + */ + @Test(groups = "unit") + public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // No workload-id header should be present + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull(); + } + enum SessionTokenType { NONE, // no session token applied USER, // userControlled session token diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java index b06d6f89b8e9..775b74785630 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.stream.Collectors; @@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient rxClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.origRxGatewayStoreModel = super.createRxGatewayProxy( sessionContainer, consistencyLevel, @@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, rxClient, - apiType); + apiType, + customHeaders); this.requests = Collections.synchronizedList(new ArrayList<>()); this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel); this.initRequestCapture(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java index 172c00f799bc..9b938d0a1520 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java @@ -57,6 +57,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds, null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti null, null, null, + null, null); RxDocumentServiceRequest req = @@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh() null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask); @@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); // connected status @@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs return new HttpClientUnderTestWrapper(origHttpClient); } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are included in + * GatewayAddressCache's defaultRequestHeaders, which are sent on every address + * resolution request. + */ + @Test(groups = { "unit" }) + public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) + * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers + * set before customHeaders are preserved. + */ + @Test(groups = { "unit" }) + public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); + customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // SDK headers should NOT be overwritten + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent"); + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION); + // Custom header should still be added + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders + * contains only SDK system headers and no extra entries. + */ + @Test(groups = { "unit" }) + public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + null); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // Should only contain SDK system headers, no workload-id + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT); + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION); + assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID); + } + public String getNameBasedCollectionLink() { return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId(); } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java index 331be53cc7af..5879e7d3e61c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java @@ -110,7 +110,7 @@ public void resolveAsync() throws Exception { GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider, userAgentContainer, - serviceConfigReader, connectionPolicy, null); + serviceConfigReader, connectionPolicy, null, null); RxDocumentServiceRequest request; request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(), OperationType.Read, @@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() { userAgentContainer, serviceConfigReader, connectionPolicy, + null, null); GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache(); GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index a57b4d9d9a0b..3bf2fdafce7c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -8,9 +8,6 @@ import com.azure.cosmos.CosmosClientBuilder; import com.azure.cosmos.TestObject; import com.azure.cosmos.implementation.HttpConstants; -import com.azure.cosmos.implementation.TestConfigurations; -import com.azure.cosmos.models.CosmosBulkExecutionOptions; -import com.azure.cosmos.models.CosmosBulkOperations; import com.azure.cosmos.models.CosmosContainerProperties; import com.azure.cosmos.models.CosmosItemRequestOptions; import com.azure.cosmos.models.CosmosItemResponse; @@ -19,6 +16,7 @@ import com.azure.cosmos.models.PartitionKeyDefinition; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; import org.testng.annotations.Test; import java.util.ArrayList; @@ -32,6 +30,10 @@ * End-to-end integration tests for the custom headers / workload-id feature. *

* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests + * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC), + * ensuring the workload-id header is correctly encoded and sent in both transport paths. */ public class WorkloadIdE2ETests extends TestSuiteBase { @@ -42,10 +44,9 @@ public class WorkloadIdE2ETests extends TestSuiteBase { private CosmosAsyncDatabase database; private CosmosAsyncContainer container; - public WorkloadIdE2ETests() { - super(new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY)); + @Factory(dataProvider = "simpleClientBuilderGatewaySession") + public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); } @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) @@ -53,9 +54,7 @@ public void beforeClass() { Map headers = new HashMap<>(); headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); - clientWithWorkloadId = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + clientWithWorkloadId = getClientBuilder() .customHeaders(headers) .buildAsyncClient(); @@ -218,9 +217,7 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() { @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void clientWithNoCustomHeadersStillWorks() { // Verify that a client without custom headers works normally (no regression) - CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) .buildAsyncClient(); try { @@ -248,9 +245,7 @@ public void clientWithNoCustomHeadersStillWorks() { @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void clientWithEmptyCustomHeaders() { // Verify that a client with empty custom headers map works normally - CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder()) .customHeaders(new HashMap<>()) .buildAsyncClient(); @@ -272,38 +267,19 @@ public void clientWithEmptyCustomHeaders() { } /** - * Verifies that a client can be configured with multiple custom headers simultaneously - * (workload-id plus an additional custom header). Confirms that all headers flow - * through the pipeline without interfering with each other. + * Verifies that unknown headers in customHeaders are rejected by the allowlist. + * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist + * ensures consistent behavior across Gateway and Direct modes. */ - @Test(groups = { "emulator" }, timeOut = TIMEOUT) - public void clientWithMultipleCustomHeaders() { - // Verify that multiple custom headers can be set simultaneously + @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void unknownCustomHeadersRejectedByAllowlist() { Map headers = new HashMap<>(); headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); headers.put("x-ms-custom-test-header", "test-value"); - CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) - .customHeaders(headers) - .buildAsyncClient(); - - try { - CosmosAsyncContainer c = clientWithMultipleHeaders - .getDatabase(DATABASE_ID) - .getContainer(CONTAINER_ID); - - TestObject doc = TestObject.create(); - CosmosItemResponse response = c - .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) - .block(); - - assertThat(response).isNotNull(); - assertThat(response.getStatusCode()).isEqualTo(201); - } finally { - safeClose(clientWithMultipleHeaders); - } + // Should throw IllegalArgumentException due to unknown header + copyCosmosClientBuilder(getClientBuilder()) + .customHeaders(headers); } @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index aea282be566c..e4cdf6ca1e30 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -13,6 +13,7 @@ import com.azure.cosmos.implementation.ConnectionPolicy; import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot; import com.azure.cosmos.implementation.DiagnosticsProvider; +import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.implementation.Strings; import com.azure.cosmos.implementation.WriteRetryPolicy; import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList; @@ -158,6 +159,20 @@ public class CosmosClientBuilder implements private Function containerFactory = null; private Map customHeaders; + /** + * Allowlist of headers permitted in {@link #customHeaders(Map)}. + *

+ * In Direct mode (RNTBD), only headers with explicit encoding support in + * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped. + * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header + * is allowed here, it works in both modes. To add a new allowed header, you must also add + * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry + + * {@code RntbdRequestHeaders.addXxx()} method). + */ + private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet( + new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID)) + ); + /** * Instantiates a new Cosmos client builder. */ @@ -739,9 +754,13 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { /** * Sets custom HTTP headers that will be included with every request from this client. *

- * These headers are sent with all requests. For Direct/RNTBD mode, only known headers - * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers - * work only in Gateway mode. + * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is + * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw + * {@link IllegalArgumentException}. + *

+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit + * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist + * ensures consistent behavior across both Gateway and Direct modes. *

* If the same header is also set on request options (e.g., * {@code CosmosItemRequestOptions.setHeader(String, String)}), @@ -749,8 +768,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { * * @param customHeaders map of header name to value * @return current CosmosClientBuilder + * @throws IllegalArgumentException if any header key is not in the allowlist, or if the + * workload-id value is not a valid integer */ public CosmosClientBuilder customHeaders(Map customHeaders) { + if (customHeaders != null) { + for (Map.Entry entry : customHeaders.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + + if (!ALLOWED_CUSTOM_HEADERS.contains(key)) { + throw new IllegalArgumentException( + "Header '" + key + "' is not allowed in customHeaders. " + + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS); + } + + // Validate workload-id value is a valid integer (range validation is left to the backend) + if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) { + try { + Integer.parseInt(value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid value '" + value + "' for header '" + key + + "'. The value must be a valid integer.", e); + } + } + } + } this.customHeaders = customHeaders; return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 2f0bd4271d86..122542a8810a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -863,7 +863,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.userAgentContainer, this.globalEndpointManager, this.reactorHttpClient, - this.apiType); + this.apiType, + this.customHeaders); this.thinProxy = createThinProxy(this.sessionContainer, this.consistencyLevel, @@ -969,7 +970,8 @@ private void initializeDirectConnectivity() { // this.gatewayConfigurationReader, null, this.connectionPolicy, - this.apiType); + this.apiType, + this.customHeaders); this.storeClientFactory = new StoreClientFactory( this.addressResolver, @@ -1013,7 +1015,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { return new RxGatewayStoreModel( this, sessionContainer, @@ -1022,7 +1025,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, httpClient, - apiType); + apiType, + customHeaders); } ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java index 979c528b32bb..a197723f0ae3 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java @@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize private GatewayServiceConfigurationReader gatewayServiceConfigurationReader; private RxClientCollectionCache collectionCache; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public RxGatewayStoreModel( DiagnosticsClientContext clientContext, @@ -100,7 +101,8 @@ public RxGatewayStoreModel( UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.clientContext = clientContext; @@ -116,6 +118,7 @@ public RxGatewayStoreModel( this.httpClient = httpClient; this.sessionContainer = sessionContainer; + this.customHeaders = customHeaders; } public RxGatewayStoreModel(RxGatewayStoreModel inner) { @@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) { this.httpClient = inner.httpClient; this.sessionContainer = inner.sessionContainer; + this.customHeaders = inner.customHeaders; } protected Map getDefaultHeaders( @@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics(); } + // Apply client-level custom headers (e.g., workload-id) to all requests + // including metadata requests (collection cache, partition key range, etc.) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + for (Map.Entry entry : this.customHeaders.entrySet()) { + // Only set if not already present — request-level headers take precedence + if (!request.getHeaders().containsKey(entry.getKey())) { + request.getHeaders().put(entry.getKey(), entry.getValue()); + } + } + } + URI uri = getUri(request); request.requestContext.resourcePhysicalAddress = uri.toString(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java index d32e5d901f18..ff139e203d2e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java @@ -56,7 +56,8 @@ public ThinClientStoreModel( userAgentContainer, globalEndpointManager, httpClient, - ApiType.SQL); + ApiType.SQL, + null); String userAgent = userAgentContainer != null ? userAgentContainer.getUserAgent() diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index e62d7b8c6ca4..7c761335b782 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -123,7 +123,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this.clientContext = clientContext; try { @@ -165,6 +166,14 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); + // Apply client-level custom headers (e.g., workload-id) to metadata requests + // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten + if (customHeaders != null && !customHeaders.isEmpty()) { + for (Map.Entry entry : customHeaders.entrySet()) { + this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + this.lastForcedRefreshMap = new ConcurrentHashMap<>(); this.globalEndpointManager = globalEndpointManager; this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor; @@ -188,7 +197,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this(clientContext, serviceEndpoint, protocol, @@ -200,7 +210,8 @@ public GatewayAddressCache( globalEndpointManager, connectionPolicy, proactiveOpenConnectionsProcessor, - gatewayServerErrorInjector); + gatewayServerErrorInjector, + customHeaders); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java index 00905682b4d1..2fd5287da028 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java @@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver { private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor; private ConnectionPolicy connectionPolicy; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public GlobalAddressResolver( DiagnosticsClientContext diagnosticsClientContext, @@ -74,7 +75,8 @@ public GlobalAddressResolver( UserAgentContainer userAgentContainer, GatewayServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.diagnosticsClientContext = diagnosticsClientContext; this.httpClient = httpClient; this.endpointManager = endpointManager; @@ -86,6 +88,7 @@ public GlobalAddressResolver( this.serviceConfigReader = serviceConfigReader; this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled(); this.connectionPolicy = connectionPolicy; + this.customHeaders = customHeaders; int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0; this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover) @@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) { this.endpointManager, this.connectionPolicy, this.proactiveOpenConnectionsProcessor, - this.gatewayServerErrorInjector); + this.gatewayServerErrorInjector, + this.customHeaders); AddressResolver addressResolver = new AddressResolver(); addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache); EndpointCache cache = new EndpointCache(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index f6e570258042..de2d769f789b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -366,6 +366,22 @@ public Set getKeywordIdentifiers() { return this.actualRequestOptions.getKeywordIdentifiers(); } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosReadManyRequestOptions. + */ + public CosmosReadManyRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + CosmosQueryRequestOptionsBase getImpl() { return this.actualRequestOptions; }