From f894bf425c9b5be3800a32ffcad4d977a108620d Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 17:33:05 -0600
Subject: [PATCH 1/5] workload-id feature - initial commit
---
.../com/azure/cosmos/CustomHeadersTests.java | 170 +++++++++
.../rntbd/RntbdWorkloadIdTests.java | 115 ++++++
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 327 ++++++++++++++++++
.../com/azure/cosmos/CosmosAsyncClient.java | 1 +
.../com/azure/cosmos/CosmosClientBuilder.java | 29 ++
.../implementation/AsyncDocumentClient.java | 9 +-
.../cosmos/implementation/HttpConstants.java | 3 +
.../implementation/RxDocumentClientImpl.java | 61 ++++
.../rntbd/RntbdConstants.java | 3 +-
.../rntbd/RntbdRequestHeaders.java | 19 +
.../models/CosmosBatchRequestOptions.java | 13 +-
.../models/CosmosBulkExecutionOptions.java | 12 +-
.../CosmosChangeFeedRequestOptions.java | 13 +-
.../models/CosmosItemRequestOptions.java | 15 +-
.../models/CosmosQueryRequestOptions.java | 16 +
15 files changed, 784 insertions(+), 22 deletions(-)
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
new file mode 100644
index 000000000000..3c95c8bd2687
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
@@ -0,0 +1,170 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.models.CosmosBatchRequestOptions;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.FeedRange;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
+ *
+ * These tests verify the public API surface: builder fluent methods, getter behavior,
+ * null/empty handling, and that setHeader() is publicly accessible on all request options classes.
+ */
+public class CustomHeadersTests {
+
+ /**
+ * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders()
+ * are stored correctly and retrievable via getCustomHeaders().
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersSetOnBuilder() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "25");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25");
+ }
+
+ /**
+ * Verifies that passing null to customHeaders() does not throw and that
+ * getCustomHeaders() returns null, ensuring graceful null handling.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersNullHandledGracefully() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(null);
+
+ assertThat(builder.getCustomHeaders()).isNull();
+ }
+
+ /**
+ * Verifies that passing an empty map to customHeaders() is accepted and
+ * getCustomHeaders() returns an empty (not null) map.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersEmptyMapHandled() {
+ Map emptyHeaders = new HashMap<>();
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(emptyHeaders);
+
+ assertThat(builder.getCustomHeaders()).isEmpty();
+ }
+
+ /**
+ * Verifies that multiple custom headers can be set at once on the builder and
+ * all entries are preserved and retrievable with correct keys and values.
+ */
+ @Test(groups = { "unit" })
+ public void multipleCustomHeadersSupported() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "15");
+ headers.put("x-ms-custom-header", "value");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).hasSize(2);
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15");
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value");
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on CRUD operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnItemRequestOptionsIsPublic() {
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "15");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on batch operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBatchRequestOptionsIsPublic() {
+ CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "20");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on change feed operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnChangeFeedRequestOptionsIsPublic() {
+ CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
+ .createForProcessingFromBeginning(FeedRange.forFullRange())
+ .setHeader("x-ms-cosmos-workload-id", "25");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on bulk ingestion operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBulkExecutionOptionsIsPublic() {
+ CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
+ .setHeader("x-ms-cosmos-workload-id", "30");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on query operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnQueryRequestOptionsIsPublic() {
+ CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "35");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
+ * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
+ * by the Cosmos DB service.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdHttpHeaderConstant() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
new file mode 100644
index 000000000000..9ca123e16160
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.implementation.directconnectivity.rntbd;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import org.testng.annotations.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants.
+ *
+ *
+ * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC),
+ * correct token type (Byte), is not required, and is not in the thin-client ordered header list
+ * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()).
+ */
+public class RntbdWorkloadIdTests {
+
+ /**
+ * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders
+ * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and
+ * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdConstantExists() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ /**
+ * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader
+ * with the correct wire ID (0x00DC). This ID is used to identify the header in the
+ * binary RNTBD protocol when communicating in Direct mode.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderExists() {
+ // Verify WorkloadId enum value exists with correct ID
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader).isNotNull();
+ assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC);
+ }
+
+ /**
+ * Verifies that the WorkloadId RNTBD header is defined as Byte token type,
+ * consistent with the ThroughputBucket pattern. The workload ID value (1-50)
+ * is encoded as a single byte on the wire.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderIsByteType() {
+ // Verify WorkloadId is Byte type (same as ThroughputBucket pattern)
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte);
+ }
+
+ /**
+ * Verifies that WorkloadId is not a required RNTBD header. The header is optional —
+ * requests without a workload ID are valid and should not be rejected by the SDK.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderIsNotRequired() {
+ // WorkloadId should not be a required header
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader.isRequired()).isFalse();
+ }
+
+ /**
+ * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client
+ * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is
+ * excluded from this list and will be auto-encoded in the second pass of
+ * RntbdTokenStream.encode() along with other non-ordered headers.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdNotInThinClientOrderedList() {
+ // WorkloadId should NOT be in thinClientHeadersInOrderList
+ // It will be automatically encoded in the second pass of RntbdTokenStream.encode()
+ assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList)
+ .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId);
+ }
+
+ /**
+ * Verifies that valid workload ID values (1-50) can be parsed from String to int
+ * and cast to byte without data loss. Note: the SDK itself does not validate the
+ * range — this test confirms the encoding path works for expected values.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdValidValues() {
+ // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly
+ String[] validValues = {"1", "25", "50"};
+ for (String value : validValues) {
+ int parsed = Integer.parseInt(value);
+ byte byteVal = (byte) parsed;
+ assertThat(byteVal).isBetween((byte) 1, (byte) 50);
+ }
+ }
+
+ /**
+ * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause
+ * exceptions in the SDK's parsing path. The SDK intentionally does not validate
+ * the range — invalid values are accepted and sent to the service, which silently
+ * ignores them.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdInvalidValuesAcceptedBySdk() {
+ // SDK does NOT validate range — service silently ignores invalid values
+ // These should not throw exceptions in SDK
+ String[] invalidValues = {"0", "51", "-1", "100"};
+ for (String value : invalidValues) {
+ int parsed = Integer.parseInt(value);
+ byte byteVal = (byte) parsed;
+ // SDK accepts any integer value that fits in a byte
+ assertThat(byteVal).isNotNull();
+ }
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
new file mode 100644
index 000000000000..85b87090bb36
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -0,0 +1,327 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.rx;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosAsyncDatabase;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.TestObject;
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.implementation.TestConfigurations;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosBulkOperations;
+import com.azure.cosmos.models.CosmosContainerProperties;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosItemResponse;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.PartitionKey;
+import com.azure.cosmos.models.PartitionKeyDefinition;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * End-to-end integration tests for the custom headers / workload-id feature.
+ *
+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ * These tests create a real database and container, then execute CRUD and query operations
+ * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level.
+ *
+ * What is verified:
+ * 1. CRUD operations succeed with client-level custom headers (workload-id)
+ * 2. Per-request header overrides work via setHeader()
+ * 3. Client with no custom headers continues to work (no regression)
+ * 4. Query operations succeed with workload-id
+ * 5. Empty headers and multiple headers are handled correctly
+ *
+
+ */
+public class WorkloadIdE2ETests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ public WorkloadIdE2ETests() {
+ super(new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY));
+ }
+
+ @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+ public void beforeClass() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+
+ clientWithWorkloadId = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(headers)
+ .buildAsyncClient();
+
+ database = createDatabase(clientWithWorkloadId, DATABASE_ID);
+
+ PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
+ ArrayList paths = new ArrayList<>();
+ paths.add("/mypk");
+ partitionKeyDef.setPaths(paths);
+ CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef);
+ database.createContainer(containerProperties).block();
+ container = database.getContainer(CONTAINER_ID);
+ }
+
+ /**
+ * Smoke test: verifies that a create (POST) operation succeeds when the client
+ * has a workload-id custom header set at the builder level. Confirms the header
+ * flows through the request pipeline without causing errors.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void createItemWithClientLevelWorkloadId() {
+ // Smoke test: verify create operation succeeds with client-level workload-id header
+ TestObject doc = TestObject.create();
+
+ CosmosItemResponse response = container
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ }
+
+ /**
+ * Verifies that a read (GET) operation succeeds with the client-level workload-id
+ * header and that the correct document is returned. Ensures the header does not
+ * interfere with normal read semantics.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void readItemWithClientLevelWorkloadId() {
+ // Verify read operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosItemResponse response = container
+ .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class)
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(200);
+ assertThat(response.getItem().getId()).isEqualTo(doc.getId());
+ }
+
+ /**
+ * Verifies that a replace (PUT) operation succeeds with the client-level workload-id
+ * header. Confirms the header propagates correctly for update operations.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void replaceItemWithClientLevelWorkloadId() {
+ // Verify replace operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ doc.setStringProp("updated-" + UUID.randomUUID());
+ CosmosItemResponse response = container
+ .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(200);
+ }
+
+ /**
+ * Verifies that a delete operation succeeds with the client-level workload-id header
+ * and returns the expected 204 No Content status code.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void deleteItemWithClientLevelWorkloadId() {
+ // Verify delete operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosItemResponse response = container
+ .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(204);
+ }
+
+ /**
+ * Verifies that a per-request workload-id header override via
+ * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header
+ * (value "30") should take precedence over the client-level default (value "15").
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void createItemWithRequestLevelWorkloadIdOverride() {
+ // Verify per-request header override works — request-level should take precedence
+ TestObject doc = TestObject.create();
+
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30");
+
+ CosmosItemResponse response = container
+ .createItem(doc, new PartitionKey(doc.getMypk()), options)
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ }
+
+ /**
+ * Verifies that a cross-partition query operation succeeds when the client has a
+ * workload-id custom header. Confirms the header flows correctly through the
+ * query pipeline and does not affect result correctness.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void queryItemsWithClientLevelWorkloadId() {
+ // Verify query operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions();
+ long count = container
+ .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class)
+ .collectList()
+ .block()
+ .size();
+
+ assertThat(count).isGreaterThanOrEqualTo(1);
+ }
+
+ /**
+ * Verifies that a per-request workload-id header override on
+ * {@code CosmosQueryRequestOptions.setHeader()} works for query operations.
+ * The request-level header (value "42") should take precedence over the
+ * client-level default.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void queryItemsWithRequestLevelWorkloadIdOverride() {
+ // Verify per-request header override on query options works
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions()
+ .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+
+ long count = container
+ .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class)
+ .collectList()
+ .block()
+ .size();
+
+ assertThat(count).isGreaterThanOrEqualTo(1);
+ }
+
+ /**
+ * Regression test: verifies that a client created without any custom headers
+ * continues to work normally. Ensures the custom headers feature does not
+ * introduce regressions for clients that do not use it.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithNoCustomHeadersStillWorks() {
+ // Verify that a client without custom headers works normally (no regression)
+ CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithoutHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithoutHeaders);
+ }
+ }
+
+ /**
+ * Verifies that a client created with an empty custom headers map works normally.
+ * An empty map should behave identically to no custom headers — no errors,
+ * no unexpected behavior.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithEmptyCustomHeaders() {
+ // Verify that a client with empty custom headers map works normally
+ CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(new HashMap<>())
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithEmptyHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithEmptyHeaders);
+ }
+ }
+
+ /**
+ * Verifies that a client can be configured with multiple custom headers simultaneously
+ * (workload-id plus an additional custom header). Confirms that all headers flow
+ * through the pipeline without interfering with each other.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithMultipleCustomHeaders() {
+ // Verify that multiple custom headers can be set simultaneously
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20");
+ headers.put("x-ms-custom-test-header", "test-value");
+
+ CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(headers)
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithMultipleHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithMultipleHeaders);
+ }
+ }
+
+ @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
+ public void afterClass() {
+ safeDeleteDatabase(database);
+ safeClose(clientWithWorkloadId);
+ }
+}
+
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
index ec0dd64af008..f54f44482db5 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
@@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable {
.withDefaultSerializer(this.defaultCustomSerializer)
.withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled())
.withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled())
+ .withCustomHeaders(builder.getCustomHeaders())
.build();
this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index 12d022e69ee7..aea282be566c 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -37,6 +37,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
@@ -155,6 +156,7 @@ public class CosmosClientBuilder implements
private boolean serverCertValidationDisabled = false;
private Function containerFactory = null;
+ private Map customHeaders;
/**
* Instantiates a new Cosmos client builder.
@@ -734,6 +736,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
return this;
}
+ /**
+ * Sets custom HTTP headers that will be included with every request from this client.
+ *
+ * These headers are sent with all requests. For Direct/RNTBD mode, only known headers
+ * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers
+ * work only in Gateway mode.
+ *
+ * If the same header is also set on request options (e.g.,
+ * {@code CosmosItemRequestOptions.setHeader(String, String)}),
+ * the request-level value takes precedence over the client-level value.
+ *
+ * @param customHeaders map of header name to value
+ * @return current CosmosClientBuilder
+ */
+ public CosmosClientBuilder customHeaders(Map customHeaders) {
+ this.customHeaders = customHeaders;
+ return this;
+ }
+
+ /**
+ * Gets the custom headers configured on this builder.
+ * @return the custom headers map, or null if not set
+ */
+ Map getCustomHeaders() {
+ return this.customHeaders;
+ }
+
/**
* Sets the retry policy options associated with the DocumentClient instance.
*
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
index 03590c1f8a5d..7953721019c5 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
@@ -116,6 +116,7 @@ class Builder {
private boolean isRegionScopedSessionCapturingEnabled;
private boolean isPerPartitionAutomaticFailoverEnabled;
private List operationPolicies;
+ private Map customHeaders;
public Builder withServiceEndpoint(String serviceEndpoint) {
try {
@@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu
return this;
}
+ public Builder withCustomHeaders(Map customHeaders) {
+ this.customHeaders = customHeaders;
+ return this;
+ }
+
private void ifThrowIllegalArgException(boolean value, String error) {
if (value) {
throw new IllegalArgumentException(error);
@@ -328,7 +334,8 @@ public AsyncDocumentClient build() {
defaultCustomSerializer,
isRegionScopedSessionCapturingEnabled,
operationPolicies,
- isPerPartitionAutomaticFailoverEnabled);
+ isPerPartitionAutomaticFailoverEnabled,
+ customHeaders);
client.init(state, null);
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
index 4e283defbc1d..32378ef0cc8d 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
@@ -298,6 +298,9 @@ public static class HttpHeaders {
// Region affinity headers
public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only";
+
+ // Workload ID header for Azure Monitor metrics attribution
+ public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id";
}
public static class A_IMHeaderValues {
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index 192f8175978f..2f0bd4271d86 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -293,6 +293,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization
private final AtomicReference cachedCosmosAsyncClientSnapshot;
private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads;
private Consumer perPartitionFailoverConfigModifier;
+ private Map customHeaders;
public RxDocumentClientImpl(URI serviceEndpoint,
String masterKeyOrResourceToken,
@@ -366,6 +367,60 @@ public RxDocumentClientImpl(URI serviceEndpoint,
boolean isRegionScopedSessionCapturingEnabled,
List operationPolicies,
boolean isPerPartitionAutomaticFailoverEnabled) {
+ this(
+ serviceEndpoint,
+ masterKeyOrResourceToken,
+ permissionFeed,
+ connectionPolicy,
+ consistencyLevel,
+ readConsistencyStrategy,
+ configs,
+ cosmosAuthorizationTokenResolver,
+ credential,
+ tokenCredential,
+ sessionCapturingOverride,
+ connectionSharingAcrossClientsEnabled,
+ contentResponseOnWriteEnabled,
+ metadataCachesSnapshot,
+ apiType,
+ clientTelemetryConfig,
+ clientCorrelationId,
+ cosmosEndToEndOperationLatencyPolicyConfig,
+ sessionRetryOptions,
+ containerProactiveInitConfig,
+ defaultCustomSerializer,
+ isRegionScopedSessionCapturingEnabled,
+ operationPolicies,
+ isPerPartitionAutomaticFailoverEnabled,
+ null
+ );
+ }
+
+ public RxDocumentClientImpl(URI serviceEndpoint,
+ String masterKeyOrResourceToken,
+ List permissionFeed,
+ ConnectionPolicy connectionPolicy,
+ ConsistencyLevel consistencyLevel,
+ ReadConsistencyStrategy readConsistencyStrategy,
+ Configs configs,
+ CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver,
+ AzureKeyCredential credential,
+ TokenCredential tokenCredential,
+ boolean sessionCapturingOverride,
+ boolean connectionSharingAcrossClientsEnabled,
+ boolean contentResponseOnWriteEnabled,
+ CosmosClientMetadataCachesSnapshot metadataCachesSnapshot,
+ ApiType apiType,
+ CosmosClientTelemetryConfig clientTelemetryConfig,
+ String clientCorrelationId,
+ CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig,
+ SessionRetryOptions sessionRetryOptions,
+ CosmosContainerProactiveInitConfig containerProactiveInitConfig,
+ CosmosItemSerializer defaultCustomSerializer,
+ boolean isRegionScopedSessionCapturingEnabled,
+ List operationPolicies,
+ boolean isPerPartitionAutomaticFailoverEnabled,
+ Map customHeaders) {
this(
serviceEndpoint,
masterKeyOrResourceToken,
@@ -392,6 +447,7 @@ public RxDocumentClientImpl(URI serviceEndpoint,
this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver;
this.operationPolicies = operationPolicies;
+ this.customHeaders = customHeaders;
}
private RxDocumentClientImpl(URI serviceEndpoint,
@@ -1884,6 +1940,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten
private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) {
Map headers = new HashMap<>();
+ // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders())
+ if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
+ headers.putAll(this.customHeaders);
+ }
+
if (this.useMultipleWriteLocations) {
headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString());
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
index ba3ec8d2017d..d79231793679 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
@@ -598,7 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader {
PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false),
GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false),
ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false),
- HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false);
+ HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false),
+ WorkloadId((short)0x00DC, RntbdTokenType.Byte, false);
public static final List thinClientHeadersInOrderList = Arrays.asList(
EffectivePartitionKey,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
index 6f6e46ee695d..387e8cf3ed5a 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
@@ -23,6 +23,8 @@
import com.fasterxml.jackson.annotation.JsonFilter;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
@@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream {
// region Fields
+ private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class);
private static final String URL_TRIM = "/";
// endregion
@@ -134,6 +137,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream {
this.addGlobalDatabaseAccountName(headers);
this.addThroughputBucket(headers);
this.addHubRegionProcessingOnly(headers);
+ this.addWorkloadId(headers);
// Normal headers (Strings, Ints, Longs, etc.)
@@ -297,6 +301,8 @@ private RntbdToken getCorrelatedActivityId() {
private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); }
+ private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); }
+
private RntbdToken getGlobalDatabaseAccountName() {
return this.get(RntbdRequestHeader.GlobalDatabaseAccountName);
}
@@ -816,6 +822,19 @@ private void addHubRegionProcessingOnly(final Map headers) {
}
}
+ private void addWorkloadId(final Map headers) {
+ final String value = headers.get(HttpHeaders.WORKLOAD_ID);
+
+ if (StringUtils.isNotEmpty(value)) {
+ try {
+ final int workloadId = Integer.valueOf(value);
+ this.getWorkloadId().setValue((byte) workloadId);
+ } catch (NumberFormatException e) {
+ logger.warn("Invalid value for workload id header: {}", value, e);
+ }
+ }
+ }
+
private void addGlobalDatabaseAccountName(final Map headers)
{
final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME);
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
index 7d5a27324f95..3183fe59bdea 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
@@ -154,14 +154,17 @@ RequestOptions toRequestOptions() {
}
/**
- * Sets the custom batch request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosBatchRequestOptions.
*/
- CosmosBatchRequestOptions setHeader(String name, String value) {
+ public CosmosBatchRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
index f125c02d6725..cd688f8a0da6 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
@@ -257,13 +257,17 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat
}
/**
- * Sets the custom bulk request option value by key
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosBulkExecutionOptions.
*/
- CosmosBulkExecutionOptions setHeader(String name, String value) {
+ public CosmosBulkExecutionOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
index 3ac526de6d63..a1b675f2ffd8 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
@@ -564,14 +564,17 @@ public List getExcludedRegions() {
}
/**
- * Sets the custom change feed request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosChangeFeedRequestOptions.
*/
- CosmosChangeFeedRequestOptions setHeader(String name, String value) {
+ public CosmosChangeFeedRequestOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
index 72eb108a6428..fbc540e5baeb 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
@@ -566,14 +566,17 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre
}
/**
- * Sets the custom item request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
- *
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosItemRequestOptions.
*/
- CosmosItemRequestOptions setHeader(String name, String value) {
+ public CosmosItemRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
index 7ead6e208781..f0de81bbf823 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
@@ -260,6 +260,22 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions)
return this;
}
+ /**
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
+ * @return the CosmosQueryRequestOptions.
+ */
+ public CosmosQueryRequestOptions setHeader(String name, String value) {
+ this.actualRequestOptions.setHeader(name, value);
+ return this;
+ }
+
/**
* Gets the list of regions to exclude for the request/retries. These regions are excluded
* from the preferred region list.
From a75ab7b1bd494dfec12ad501b23923a101b1cba4 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 17:57:37 -0600
Subject: [PATCH 2/5] workload-id feature - initial commit(Spark)
---
.../cosmos/spark/CosmosClientCache.scala | 15 +-
.../spark/CosmosClientConfiguration.scala | 8 +-
.../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++-
.../cosmos/spark/CosmosClientCacheITest.scala | 17 +-
.../spark/CosmosClientConfigurationSpec.scala | 68 ++++++++
.../spark/CosmosPartitionPlannerSpec.scala | 24 ++-
.../cosmos/spark/PartitionMetadataSpec.scala | 48 ++++--
.../spark/SparkE2EWorkloadIdITest.scala | 150 ++++++++++++++++++
8 files changed, 323 insertions(+), 38 deletions(-)
create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
index e61a271aeb8b..9ad739d8f3f6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
@@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.function.BiPredicate
import scala.collection.concurrent.TrieMap
-
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
@@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}
+ // Apply custom HTTP headers (e.g., workload-id) to the builder if configured.
+ // These headers are attached to every Cosmos DB request made by this client instance.
+ if (cosmosClientConfiguration.customHeaders.isDefined) {
+ builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava)
+ }
+
var client = builder.buildAsyncClient()
if (cosmosClientConfiguration.clientInterceptors.isDefined) {
@@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Custom HTTP headers are part of the cache key because different workload-ids
+ // should produce different CosmosAsyncClient instances
+ customHeaders: Option[Map[String, String]]
)
private[this] object ClientConfigurationWrapper {
@@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientBuilderInterceptors,
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
- clientConfig.azureMonitorConfig
+ clientConfig.azureMonitorConfig,
+ clientConfig.customHeaders
)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
index 6f4e26e1f503..61fa0957af83 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
@@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration (
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Optional custom HTTP headers (e.g., workload-id) to attach to
+ // all Cosmos DB requests via CosmosClientBuilder.customHeaders()
+ customHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
@@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientBuilderInterceptors,
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
- diagnosticsConfig.azureMonitorConfig
+ diagnosticsConfig.azureMonitorConfig,
+ cosmosAccountConfig.customHeaders
)
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index eef3f6ae1f8d..6646b2e69ae5 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util
import java.util.{Locale, ServiceLoader}
+import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.{HashSet, List, Map}
import scala.collection.mutable
@@ -150,6 +151,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
+ // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
+ // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
+ // Flows through to CosmosClientBuilder.customHeaders().
+ val CustomHeaders = "spark.cosmos.customHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
@@ -295,7 +300,8 @@ private[spark] object CosmosConfigNames {
WriteOnRetryCommitInterceptor,
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
- WriteMaxRetryNoProgressIntervalInSeconds
+ WriteMaxRetryNoProgressIntervalInSeconds,
+ CustomHeaders
)
def validateConfigName(name: String): Unit = {
@@ -538,7 +544,10 @@ private case class CosmosAccountConfig(endpoint: String,
resourceGroupName: Option[String],
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
- clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ // Optional custom HTTP headers (e.g., workload-id) parsed from
+ // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder
+ customHeaders: Option[Map[String, String]]
)
private object CosmosAccountConfig extends BasicLoggingTrait {
@@ -719,6 +728,19 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN,
helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.")
+ // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
+ // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
+ // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache.
+ private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]](
+ key = CosmosConfigNames.CustomHeaders,
+ mandatory = false,
+ parseFromStringFunction = headersJson => {
+ val mapper = new com.fasterxml.jackson.databind.ObjectMapper()
+ val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
+ mapper.readValue(headersJson, typeRef).asScala.toMap
+ },
+ helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
+
private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
try {
@@ -753,6 +775,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
+ // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
+ val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig)
val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
@@ -864,7 +888,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
resourceGroupNameOpt,
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
- if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None })
+ if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
+ customHeaders)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
index ccf36791dc96..4d542c44612e 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
@@ -64,7 +64,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -91,7 +92,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -118,7 +120,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -145,7 +148,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
)
)
@@ -179,8 +183,9 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
- )
+ azureMonitorConfig = None,
+ customHeaders = None
+ )
logInfo(s"TestCase: {$testCaseName}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index 7fcc601ba016..a0627c0cf3dd 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -408,4 +408,72 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ")
configuration.azureMonitorConfig shouldEqual None
}
+
+ // Verifies that the spark.cosmos.customHeaders configuration option correctly parses
+ // a JSON string containing a single workload-id header into a Map[String, String] on
+ // CosmosClientConfiguration. This is the primary use case for the workload-id feature.
+ it should "parse customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
+ }
+
+ // Verifies that when spark.cosmos.customHeaders is not specified in the config map,
+ // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility —
+ // existing Spark jobs that don't set customHeaders continue to work without changes.
+ it should "handle missing customHeaders" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe None
+ }
+
+ // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing
+ // multiple headers into a Map with all entries preserved. This supports use cases where
+ // multiple custom headers need to be sent alongside workload-id.
+ it should "parse multiple custom headers" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get should have size 2
+ configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20"
+ configuration.customHeaders.get("x-custom-header") shouldEqual "value"
+ }
+
+ // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
+ // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases
+ // and that no headers are injected when the JSON object is empty.
+ it should "handle empty customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> "{}"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get shouldBe empty
+ }
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
index 6ef90b55989d..ab73dc4e54d3 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
@@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
index dfd14c36c80f..65274bee2b19 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
@@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
new file mode 100644
index 000000000000..d9706d0709e5
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.TestConfigurations
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.node.ObjectNode
+
+import java.util.UUID
+
+/**
+ * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector.
+ *
+ * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows
+ * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that
+ * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations
+ * initiated via Spark DataFrames (reads and writes).
+ *
+ * Requires the Cosmos DB Emulator running
+ */
+class SparkE2EWorkloadIdITest
+ extends IntegrationSpec
+ with Spark
+ with CosmosClient
+ with AutoCleanableCosmosContainer
+ with BasicLoggingTrait {
+
+ val objectMapper = new ObjectMapper()
+
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off magic.number
+ //scalastyle:off null
+
+ // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The header should be passed through to the
+ // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior.
+ "spark query with customHeaders" can "read items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""",
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+
+ val item = rowsArray(0)
+ item.getAs[String]("id") shouldEqual id
+ }
+
+ // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The item is written via Spark and then verified
+ // via a direct SDK read to confirm the write was persisted correctly.
+ "spark write with customHeaders" can "write items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testWriteItem"
+ | }
+ |""".stripMargin
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""",
+ "spark.cosmos.write.strategy" -> "ItemOverwrite",
+ "spark.cosmos.write.bulk.enabled" -> "false",
+ "spark.cosmos.serialization.inclusionMode" -> "NonDefault"
+ )
+
+ val spark_session = spark
+ import spark_session.implicits._
+ val df = spark.read.json(Seq(rawItem).toDS())
+
+ df.write.format("cosmos.oltp").options(cfg).mode("Append").save()
+
+ // Verify the item was written by reading it back via the SDK directly
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block()
+ readItem.getItem.get("id").textValue() shouldEqual id
+ readItem.getItem.get("name").textValue() shouldEqual "testWriteItem"
+ }
+
+ // Regression test: verifies that Spark read operations continue to work correctly when
+ // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not
+ // break existing behavior for clients that do not use custom headers.
+ "spark operations without customHeaders" can "still succeed" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "noHeadersItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+ rowsArray(0).getAs[String]("id") shouldEqual id
+ }
+
+ //scalastyle:on magic.number
+ //scalastyle:on multiple.string.literals
+ //scalastyle:on null
+}
From 1bbf6b9bb962fec1f576fbf3a4b0863a0a132099 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 19:08:04 -0600
Subject: [PATCH 3/5] workload-id feature - cleaning up comments
---
.../com/azure/cosmos/rx/WorkloadIdE2ETests.java | 14 +-------------
1 file changed, 1 insertion(+), 13 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
index 85b87090bb36..a57b4d9d9a0b 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -32,17 +32,6 @@
* End-to-end integration tests for the custom headers / workload-id feature.
*
* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
- * These tests create a real database and container, then execute CRUD and query operations
- * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level.
- *
- * What is verified:
- * 1. CRUD operations succeed with client-level custom headers (workload-id)
- * 2. Per-request header overrides work via setHeader()
- * 3. Client with no custom headers continues to work (no regression)
- * 4. Query operations succeed with workload-id
- * 5. Empty headers and multiple headers are handled correctly
- *
-
*/
public class WorkloadIdE2ETests extends TestSuiteBase {
@@ -82,13 +71,12 @@ public void beforeClass() {
}
/**
- * Smoke test: verifies that a create (POST) operation succeeds when the client
+ * verifies that a create (POST) operation succeeds when the client
* has a workload-id custom header set at the builder level. Confirms the header
* flows through the request pipeline without causing errors.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void createItemWithClientLevelWorkloadId() {
- // Smoke test: verify create operation succeeds with client-level workload-id header
TestObject doc = TestObject.create();
CosmosItemResponse response = container
From 7aeb8b4e030958f77e14983fc5f0d3475333effb Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 19:20:36 -0600
Subject: [PATCH 4/5] workload-id feature - addressing copilot comments
---
sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 +
.../directconnectivity/rntbd/RntbdConstants.java | 4 ++--
.../directconnectivity/rntbd/RntbdRequestHeaders.java | 2 +-
8 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
index de77b69485a8..2eeccc69dada 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
index 80072357c58f..6e72286dfab6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
index b905a025a1e6..d0d80af466d6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
index e17cecdfdac2..ae444a9c399f 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
index e9be63ef89bd..ae910280495b 100644
--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index 8475b83dc5ee..ea4fdb82e1dc 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -4,6 +4,7 @@
#### Features Added
* Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757)
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
index d79231793679..d75bf5dc88e1 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
@@ -598,8 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader {
PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false),
GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false),
ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false),
- HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false),
- WorkloadId((short)0x00DC, RntbdTokenType.Byte, false);
+ WorkloadId((short)0x00DC, RntbdTokenType.Byte, false),
+ HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false);
public static final List thinClientHeadersInOrderList = Arrays.asList(
EffectivePartitionKey,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
index 387e8cf3ed5a..46f8060387fc 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
@@ -827,7 +827,7 @@ private void addWorkloadId(final Map headers) {
if (StringUtils.isNotEmpty(value)) {
try {
- final int workloadId = Integer.valueOf(value);
+ final int workloadId = Integer.parseInt(value);
this.getWorkloadId().setValue((byte) workloadId);
} catch (NumberFormatException e) {
logger.warn("Invalid value for workload id header: {}", value, e);
From 6d00d49320caa42432aa26725befe529e5934ad4 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Mon, 2 Mar 2026 16:38:14 -0600
Subject: [PATCH 5/5] workload-id feature - addressing comments
---
.../com/azure/cosmos/spark/CosmosConfig.scala | 13 +-
.../spark/CosmosClientConfigurationSpec.scala | 13 +-
.../com/azure/cosmos/CustomHeadersTests.java | 104 ++++++++++-
.../RxDocumentClientUnderTest.java | 7 +-
.../RxGatewayStoreModelTest.java | 172 +++++++++++++++++-
.../SpyClientUnderTestFactory.java | 7 +-
.../GatewayAddressCacheTest.java | 125 +++++++++++++
.../GlobalAddressResolverTest.java | 3 +-
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 62 ++-----
.../com/azure/cosmos/CosmosClientBuilder.java | 50 ++++-
.../implementation/RxDocumentClientImpl.java | 12 +-
.../implementation/RxGatewayStoreModel.java | 17 +-
.../implementation/ThinClientStoreModel.java | 3 +-
.../GatewayAddressCache.java | 17 +-
.../GlobalAddressResolver.java | 8 +-
.../models/CosmosReadManyRequestOptions.java | 16 ++
16 files changed, 546 insertions(+), 83 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index 6646b2e69ae5..62802d23b14c 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark}
import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants
import com.azure.cosmos.implementation.routing.LocationHelper
-import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings}
+import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils}
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
@@ -735,9 +735,14 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
key = CosmosConfigNames.CustomHeaders,
mandatory = false,
parseFromStringFunction = headersJson => {
- val mapper = new com.fasterxml.jackson.databind.ObjectMapper()
- val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
- mapper.readValue(headersJson, typeRef).asScala.toMap
+ try {
+ val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
+ Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap
+ } catch {
+ case e: Exception => throw new IllegalArgumentException(
+ s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " +
+ "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
+ }
},
helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index a0627c0cf3dd..377425189f07 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -441,10 +441,11 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.customHeaders shouldBe None
}
- // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing
- // multiple headers into a Map with all entries preserved. This supports use cases where
- // multiple custom headers need to be sent alongside workload-id.
- it should "parse multiple custom headers" in {
+ // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level.
+ // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD),
+ // unknown headers are silently dropped, so the allowlist ensures consistent behavior
+ // across Gateway and Direct modes.
+ it should "reject unknown custom headers" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
@@ -454,10 +455,10 @@ class CosmosClientConfigurationSpec extends UnitSpec {
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+ // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is.
+ // The allowlist validation happens later in CosmosClientBuilder.customHeaders()
configuration.customHeaders shouldBe defined
configuration.customHeaders.get should have size 2
- configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20"
- configuration.customHeaders.get("x-custom-header") shouldEqual "value"
}
// Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
index 3c95c8bd2687..19eb03744d1a 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
@@ -9,6 +9,7 @@
import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.CosmosReadManyRequestOptions;
import com.azure.cosmos.models.FeedRange;
import org.testng.annotations.Test;
@@ -16,6 +17,7 @@
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
@@ -73,23 +75,40 @@ public void customHeadersEmptyMapHandled() {
}
/**
- * Verifies that multiple custom headers can be set at once on the builder and
- * all entries are preserved and retrievable with correct keys and values.
+ * Verifies that headers not in the allowlist are rejected with IllegalArgumentException.
+ * This ensures consistent behavior across Gateway and Direct modes — only headers with
+ * RNTBD encoding support are allowed.
*/
@Test(groups = { "unit" })
- public void multipleCustomHeadersSupported() {
+ public void unknownHeaderRejectedByAllowlist() {
Map headers = new HashMap<>();
- headers.put("x-ms-cosmos-workload-id", "15");
headers.put("x-ms-custom-header", "value");
- CosmosClientBuilder builder = new CosmosClientBuilder()
+ assertThatThrownBy(() -> new CosmosClientBuilder()
.endpoint("https://test.documents.azure.com:443/")
.key("dGVzdEtleQ==")
- .customHeaders(headers);
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header")
+ .hasMessageContaining("not allowed");
+ }
+
+ /**
+ * Verifies that a map containing both an allowed header and a disallowed header
+ * is rejected — the entire map must pass the allowlist check.
+ */
+ @Test(groups = { "unit" })
+ public void mixedAllowedAndDisallowedHeadersRejected() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "15");
+ headers.put("x-ms-custom-header", "value");
- assertThat(builder.getCustomHeaders()).hasSize(2);
- assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15");
- assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value");
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header");
}
/**
@@ -158,6 +177,19 @@ public void setHeaderOnQueryRequestOptionsIsPublic() {
assertThat(options).isNotNull();
}
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on read-many operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnReadManyRequestOptionsIsPublic() {
+ CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "40");
+
+ assertThat(options).isNotNull();
+ }
+
/**
* Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
* with the correct canonical header name "x-ms-cosmos-workload-id" as expected
@@ -167,4 +199,58 @@ public void setHeaderOnQueryRequestOptionsIsPublic() {
public void workloadIdHttpHeaderConstant() {
assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
}
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at builder level with
+ * IllegalArgumentException. This covers both Gateway and Direct modes consistently
+ * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
+ * Range validation [1, 50] is the backend's responsibility — the SDK only validates
+ * that the value is a valid integer. This avoids hardcoding a range the backend team
+ * might change in the future.
+ */
+ @Test(groups = { "unit" })
+ public void outOfRangeWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+ }
+
+ /**
+ * Verifies that a valid workload-id value passes builder validation.
+ */
+ @Test(groups = { "unit" })
+ public void validWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+ }
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
index a9f5cb35549c..d5f8b92ac7a6 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
@@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import static org.mockito.Mockito.doAnswer;
@@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy(
GlobalEndpointManager globalEndpointManager,
GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker,
HttpClient rxOrigClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
origHttpClient = rxOrigClient;
spyHttpClient = Mockito.spy(rxOrigClient);
@@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy(
userAgentContainer,
globalEndpointManager,
spyHttpClient,
- apiType);
+ apiType,
+ customHeaders);
}
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
index 54440ecfabc5..587844f4043a 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
@@ -27,6 +27,8 @@
import java.net.SocketException;
import java.net.URI;
import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
@@ -102,6 +104,7 @@ public void readTimeout() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -205,7 +209,8 @@ public void applySessionToken(
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
@@ -277,7 +282,8 @@ public void validateApiType() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
@@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception {
return false;
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are injected into
+ * outgoing HTTP requests by performRequest(). This covers metadata requests
+ * (collection cache, partition key range) that don't go through getRequestHeaders().
+ */
+ @Test(groups = "unit")
+ public void customHeadersInjectedInPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25");
+ }
+
+ /**
+ * Verifies that request-level headers take precedence over client-level customHeaders.
+ * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest()
+ * should NOT overwrite it.
+ */
+ @Test(groups = "unit")
+ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col/docs/doc1",
+ ResourceType.Document);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ // Simulate request-level header already set (e.g., by getRequestHeaders())
+ dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // Request-level header "42" should win over client-level "10"
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, performRequest() still works normally
+ * without injecting any extra headers.
+ */
+ @Test(groups = "unit")
+ public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ null);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // No workload-id header should be present
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull();
+ }
+
enum SessionTokenType {
NONE, // no session token applied
USER, // userControlled session token
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
index b06d6f89b8e9..775b74785630 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
@@ -25,6 +25,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient rxClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.origRxGatewayStoreModel = super.createRxGatewayProxy(
sessionContainer,
consistencyLevel,
@@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
userAgentContainer,
globalEndpointManager,
rxClient,
- apiType);
+ apiType,
+ customHeaders);
this.requests = Collections.synchronizedList(new ArrayList<>());
this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel);
this.initRequestCapture();
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
index 172c00f799bc..9b938d0a1520 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
@@ -57,6 +57,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
@@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds,
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti
null,
null,
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh()
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask);
@@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
// connected status
@@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs
return new HttpClientUnderTestWrapper(origHttpClient);
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are included in
+ * GatewayAddressCache's defaultRequestHeaders, which are sent on every address
+ * resolution request.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.)
+ * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers
+ * set before customHeaders are preserved.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent");
+ customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version");
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ // SDK headers should NOT be overwritten
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent");
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION);
+ // Custom header should still be added
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders
+ * contains only SDK system headers and no extra entries.
+ */
+ @Test(groups = { "unit" })
+ public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ // Should only contain SDK system headers, no workload-id
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT);
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION);
+ assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID);
+ }
+
public String getNameBasedCollectionLink() {
return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId();
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
index 331be53cc7af..5879e7d3e61c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
@@ -110,7 +110,7 @@ public void resolveAsync() throws Exception {
GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider,
userAgentContainer,
- serviceConfigReader, connectionPolicy, null);
+ serviceConfigReader, connectionPolicy, null, null);
RxDocumentServiceRequest request;
request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(),
OperationType.Read,
@@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() {
userAgentContainer,
serviceConfigReader,
connectionPolicy,
+ null,
null);
GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache();
GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class);
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
index a57b4d9d9a0b..3bf2fdafce7c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -8,9 +8,6 @@
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.TestObject;
import com.azure.cosmos.implementation.HttpConstants;
-import com.azure.cosmos.implementation.TestConfigurations;
-import com.azure.cosmos.models.CosmosBulkExecutionOptions;
-import com.azure.cosmos.models.CosmosBulkOperations;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosItemResponse;
@@ -19,6 +16,7 @@
import com.azure.cosmos.models.PartitionKeyDefinition;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Factory;
import org.testng.annotations.Test;
import java.util.ArrayList;
@@ -32,6 +30,10 @@
* End-to-end integration tests for the custom headers / workload-id feature.
*
* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests
+ * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC),
+ * ensuring the workload-id header is correctly encoded and sent in both transport paths.
*/
public class WorkloadIdE2ETests extends TestSuiteBase {
@@ -42,10 +44,9 @@ public class WorkloadIdE2ETests extends TestSuiteBase {
private CosmosAsyncDatabase database;
private CosmosAsyncContainer container;
- public WorkloadIdE2ETests() {
- super(new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY));
+ @Factory(dataProvider = "simpleClientBuilderGatewaySession")
+ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
}
@BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
@@ -53,9 +54,7 @@ public void beforeClass() {
Map headers = new HashMap<>();
headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
- clientWithWorkloadId = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ clientWithWorkloadId = getClientBuilder()
.customHeaders(headers)
.buildAsyncClient();
@@ -218,9 +217,7 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() {
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void clientWithNoCustomHeadersStillWorks() {
// Verify that a client without custom headers works normally (no regression)
- CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder())
.buildAsyncClient();
try {
@@ -248,9 +245,7 @@ public void clientWithNoCustomHeadersStillWorks() {
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void clientWithEmptyCustomHeaders() {
// Verify that a client with empty custom headers map works normally
- CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder())
.customHeaders(new HashMap<>())
.buildAsyncClient();
@@ -272,38 +267,19 @@ public void clientWithEmptyCustomHeaders() {
}
/**
- * Verifies that a client can be configured with multiple custom headers simultaneously
- * (workload-id plus an additional custom header). Confirms that all headers flow
- * through the pipeline without interfering with each other.
+ * Verifies that unknown headers in customHeaders are rejected by the allowlist.
+ * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist
+ * ensures consistent behavior across Gateway and Direct modes.
*/
- @Test(groups = { "emulator" }, timeOut = TIMEOUT)
- public void clientWithMultipleCustomHeaders() {
- // Verify that multiple custom headers can be set simultaneously
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class)
+ public void unknownCustomHeadersRejectedByAllowlist() {
Map headers = new HashMap<>();
headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20");
headers.put("x-ms-custom-test-header", "test-value");
- CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
- .customHeaders(headers)
- .buildAsyncClient();
-
- try {
- CosmosAsyncContainer c = clientWithMultipleHeaders
- .getDatabase(DATABASE_ID)
- .getContainer(CONTAINER_ID);
-
- TestObject doc = TestObject.create();
- CosmosItemResponse response = c
- .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
- .block();
-
- assertThat(response).isNotNull();
- assertThat(response.getStatusCode()).isEqualTo(201);
- } finally {
- safeClose(clientWithMultipleHeaders);
- }
+ // Should throw IllegalArgumentException due to unknown header
+ copyCosmosClientBuilder(getClientBuilder())
+ .customHeaders(headers);
}
@AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index aea282be566c..e4cdf6ca1e30 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -13,6 +13,7 @@
import com.azure.cosmos.implementation.ConnectionPolicy;
import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot;
import com.azure.cosmos.implementation.DiagnosticsProvider;
+import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.Strings;
import com.azure.cosmos.implementation.WriteRetryPolicy;
import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList;
@@ -158,6 +159,20 @@ public class CosmosClientBuilder implements
private Function containerFactory = null;
private Map customHeaders;
+ /**
+ * Allowlist of headers permitted in {@link #customHeaders(Map)}.
+ *
+ * In Direct mode (RNTBD), only headers with explicit encoding support in
+ * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped.
+ * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header
+ * is allowed here, it works in both modes. To add a new allowed header, you must also add
+ * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry +
+ * {@code RntbdRequestHeaders.addXxx()} method).
+ */
+ private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet(
+ new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID))
+ );
+
/**
* Instantiates a new Cosmos client builder.
*/
@@ -739,9 +754,13 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
/**
* Sets custom HTTP headers that will be included with every request from this client.
*
- * These headers are sent with all requests. For Direct/RNTBD mode, only known headers
- * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers
- * work only in Gateway mode.
+ * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is
+ * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw
+ * {@link IllegalArgumentException}.
+ *
+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit
+ * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist
+ * ensures consistent behavior across both Gateway and Direct modes.
*
* If the same header is also set on request options (e.g.,
* {@code CosmosItemRequestOptions.setHeader(String, String)}),
@@ -749,8 +768,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
*
* @param customHeaders map of header name to value
* @return current CosmosClientBuilder
+ * @throws IllegalArgumentException if any header key is not in the allowlist, or if the
+ * workload-id value is not a valid integer
*/
public CosmosClientBuilder customHeaders(Map customHeaders) {
+ if (customHeaders != null) {
+ for (Map.Entry entry : customHeaders.entrySet()) {
+ String key = entry.getKey();
+ String value = entry.getValue();
+
+ if (!ALLOWED_CUSTOM_HEADERS.contains(key)) {
+ throw new IllegalArgumentException(
+ "Header '" + key + "' is not allowed in customHeaders. "
+ + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS);
+ }
+
+ // Validate workload-id value is a valid integer (range validation is left to the backend)
+ if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) {
+ try {
+ Integer.parseInt(value);
+ } catch (NumberFormatException e) {
+ throw new IllegalArgumentException(
+ "Invalid value '" + value + "' for header '" + key
+ + "'. The value must be a valid integer.", e);
+ }
+ }
+ }
+ }
this.customHeaders = customHeaders;
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index 2f0bd4271d86..122542a8810a 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -863,7 +863,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func
this.userAgentContainer,
this.globalEndpointManager,
this.reactorHttpClient,
- this.apiType);
+ this.apiType,
+ this.customHeaders);
this.thinProxy = createThinProxy(this.sessionContainer,
this.consistencyLevel,
@@ -969,7 +970,8 @@ private void initializeDirectConnectivity() {
// this.gatewayConfigurationReader,
null,
this.connectionPolicy,
- this.apiType);
+ this.apiType,
+ this.customHeaders);
this.storeClientFactory = new StoreClientFactory(
this.addressResolver,
@@ -1013,7 +1015,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
return new RxGatewayStoreModel(
this,
sessionContainer,
@@ -1022,7 +1025,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
userAgentContainer,
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ customHeaders);
}
ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
index 979c528b32bb..a197723f0ae3 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
@@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize
private GatewayServiceConfigurationReader gatewayServiceConfigurationReader;
private RxClientCollectionCache collectionCache;
private GatewayServerErrorInjector gatewayServerErrorInjector;
+ private final Map customHeaders;
public RxGatewayStoreModel(
DiagnosticsClientContext clientContext,
@@ -100,7 +101,8 @@ public RxGatewayStoreModel(
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.clientContext = clientContext;
@@ -116,6 +118,7 @@ public RxGatewayStoreModel(
this.httpClient = httpClient;
this.sessionContainer = sessionContainer;
+ this.customHeaders = customHeaders;
}
public RxGatewayStoreModel(RxGatewayStoreModel inner) {
@@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) {
this.httpClient = inner.httpClient;
this.sessionContainer = inner.sessionContainer;
+ this.customHeaders = inner.customHeaders;
}
protected Map getDefaultHeaders(
@@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r
request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics();
}
+ // Apply client-level custom headers (e.g., workload-id) to all requests
+ // including metadata requests (collection cache, partition key range, etc.)
+ if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
+ for (Map.Entry entry : this.customHeaders.entrySet()) {
+ // Only set if not already present — request-level headers take precedence
+ if (!request.getHeaders().containsKey(entry.getKey())) {
+ request.getHeaders().put(entry.getKey(), entry.getValue());
+ }
+ }
+ }
+
URI uri = getUri(request);
request.requestContext.resourcePhysicalAddress = uri.toString();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
index d32e5d901f18..ff139e203d2e 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
@@ -56,7 +56,8 @@ public ThinClientStoreModel(
userAgentContainer,
globalEndpointManager,
httpClient,
- ApiType.SQL);
+ ApiType.SQL,
+ null);
String userAgent = userAgentContainer != null
? userAgentContainer.getUserAgent()
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
index e62d7b8c6ca4..7c761335b782 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
@@ -123,7 +123,8 @@ public GatewayAddressCache(
GlobalEndpointManager globalEndpointManager,
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
- GatewayServerErrorInjector gatewayServerErrorInjector) {
+ GatewayServerErrorInjector gatewayServerErrorInjector,
+ Map customHeaders) {
this.clientContext = clientContext;
try {
@@ -165,6 +166,14 @@ public GatewayAddressCache(
HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES,
HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES);
+ // Apply client-level custom headers (e.g., workload-id) to metadata requests
+ // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten
+ if (customHeaders != null && !customHeaders.isEmpty()) {
+ for (Map.Entry entry : customHeaders.entrySet()) {
+ this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue());
+ }
+ }
+
this.lastForcedRefreshMap = new ConcurrentHashMap<>();
this.globalEndpointManager = globalEndpointManager;
this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor;
@@ -188,7 +197,8 @@ public GatewayAddressCache(
GlobalEndpointManager globalEndpointManager,
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
- GatewayServerErrorInjector gatewayServerErrorInjector) {
+ GatewayServerErrorInjector gatewayServerErrorInjector,
+ Map customHeaders) {
this(clientContext,
serviceEndpoint,
protocol,
@@ -200,7 +210,8 @@ public GatewayAddressCache(
globalEndpointManager,
connectionPolicy,
proactiveOpenConnectionsProcessor,
- gatewayServerErrorInjector);
+ gatewayServerErrorInjector,
+ customHeaders);
}
@Override
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
index 00905682b4d1..2fd5287da028 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
@@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver {
private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor;
private ConnectionPolicy connectionPolicy;
private GatewayServerErrorInjector gatewayServerErrorInjector;
+ private final Map customHeaders;
public GlobalAddressResolver(
DiagnosticsClientContext diagnosticsClientContext,
@@ -74,7 +75,8 @@ public GlobalAddressResolver(
UserAgentContainer userAgentContainer,
GatewayServiceConfigurationReader serviceConfigReader,
ConnectionPolicy connectionPolicy,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.diagnosticsClientContext = diagnosticsClientContext;
this.httpClient = httpClient;
this.endpointManager = endpointManager;
@@ -86,6 +88,7 @@ public GlobalAddressResolver(
this.serviceConfigReader = serviceConfigReader;
this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled();
this.connectionPolicy = connectionPolicy;
+ this.customHeaders = customHeaders;
int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0;
this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover)
@@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) {
this.endpointManager,
this.connectionPolicy,
this.proactiveOpenConnectionsProcessor,
- this.gatewayServerErrorInjector);
+ this.gatewayServerErrorInjector,
+ this.customHeaders);
AddressResolver addressResolver = new AddressResolver();
addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache);
EndpointCache cache = new EndpointCache();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
index f6e570258042..de2d769f789b 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
@@ -366,6 +366,22 @@ public Set getKeywordIdentifiers() {
return this.actualRequestOptions.getKeywordIdentifiers();
}
+ /**
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
+ * @return the CosmosReadManyRequestOptions.
+ */
+ public CosmosReadManyRequestOptions setHeader(String name, String value) {
+ this.actualRequestOptions.setHeader(name, value);
+ return this;
+ }
+
CosmosQueryRequestOptionsBase> getImpl() {
return this.actualRequestOptions;
}