diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
index a66d7b9b7282..5c00937013ad 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
index 3edf40a6b496..705334e4c3a1 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
index be5982831836..97bff869bb59 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
index cf62ab372904..d922e4a579d3 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
index 3b87ef08c3a0..f31024628cb7 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
@@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.function.BiPredicate
import scala.collection.concurrent.TrieMap
-
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
@@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}
+ // Apply custom HTTP headers (e.g., workload-id) to the builder if configured.
+ // These headers are attached to every Cosmos DB request made by this client instance.
+ if (cosmosClientConfiguration.customHeaders.isDefined) {
+ builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava)
+ }
+
var client = builder.buildAsyncClient()
if (cosmosClientConfiguration.clientInterceptors.isDefined) {
@@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Custom HTTP headers are part of the cache key because different workload-ids
+ // should produce different CosmosAsyncClient instances
+ customHeaders: Option[Map[String, String]]
)
private[this] object ClientConfigurationWrapper {
@@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientBuilderInterceptors,
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
- clientConfig.azureMonitorConfig
+ clientConfig.azureMonitorConfig,
+ clientConfig.customHeaders
)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
index 6f4e26e1f503..61fa0957af83 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
@@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration (
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Optional custom HTTP headers (e.g., workload-id) to attach to
+ // all Cosmos DB requests via CosmosClientBuilder.customHeaders()
+ customHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
@@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientBuilderInterceptors,
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
- diagnosticsConfig.azureMonitorConfig
+ diagnosticsConfig.azureMonitorConfig,
+ cosmosAccountConfig.customHeaders
)
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index 951f4735444d..928b0cd09445 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark}
import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants
import com.azure.cosmos.implementation.routing.LocationHelper
-import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings}
+import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils}
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
@@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util
import java.util.{Locale, ServiceLoader}
+import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.{HashSet, List, Map}
import scala.collection.mutable
@@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
+ // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
+ // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
+ // Flows through to CosmosClientBuilder.customHeaders().
+ val CustomHeaders = "spark.cosmos.customHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
@@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames {
WriteOnRetryCommitInterceptor,
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
- WriteMaxRetryNoProgressIntervalInSeconds
+ WriteMaxRetryNoProgressIntervalInSeconds,
+ CustomHeaders
)
def validateConfigName(name: String): Unit = {
@@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String,
resourceGroupName: Option[String],
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
- clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ // Optional custom HTTP headers (e.g., workload-id) parsed from
+ // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder
+ customHeaders: Option[Map[String, String]]
)
private object CosmosAccountConfig extends BasicLoggingTrait {
@@ -727,6 +736,24 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN,
helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.")
+ // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
+ // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
+ // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache.
+ private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]](
+ key = CosmosConfigNames.CustomHeaders,
+ mandatory = false,
+ parseFromStringFunction = headersJson => {
+ try {
+ val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
+ Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap
+ } catch {
+ case e: Exception => throw new IllegalArgumentException(
+ s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " +
+ "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
+ }
+ },
+ helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
+
private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
try {
@@ -761,6 +788,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
+ // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
+ val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig)
val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
@@ -880,7 +909,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
resourceGroupNameOpt,
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
- if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None })
+ if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
+ customHeaders)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
index ccf36791dc96..4d542c44612e 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
@@ -64,7 +64,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -91,7 +92,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -118,7 +120,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -145,7 +148,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
)
)
@@ -179,8 +183,9 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
- )
+ azureMonitorConfig = None,
+ customHeaders = None
+ )
logInfo(s"TestCase: {$testCaseName}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index 7fcc601ba016..377425189f07 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -408,4 +408,73 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ")
configuration.azureMonitorConfig shouldEqual None
}
+
+ // Verifies that the spark.cosmos.customHeaders configuration option correctly parses
+ // a JSON string containing a single workload-id header into a Map[String, String] on
+ // CosmosClientConfiguration. This is the primary use case for the workload-id feature.
+ it should "parse customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
+ }
+
+ // Verifies that when spark.cosmos.customHeaders is not specified in the config map,
+ // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility —
+ // existing Spark jobs that don't set customHeaders continue to work without changes.
+ it should "handle missing customHeaders" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe None
+ }
+
+ // Verifies that spark.cosmos.customHeaders accepts unknown headers at the parsing
+ // level: the JSON is parsed as-is and stored on CosmosClientConfiguration. Any
+ // allowlist validation (needed because Direct mode / RNTBD silently drops unknown
+ // headers) happens later, in CosmosClientBuilder.customHeaders().
+ it should "accept unknown custom headers at parse time" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is.
+ // The allowlist validation happens later in CosmosClientBuilder.customHeaders()
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get should have size 2
+ }
+
+ // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
+ // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases
+ // and that no headers are injected when the JSON object is empty.
+ it should "handle empty customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> "{}"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get shouldBe empty
+ }
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
index 6ef90b55989d..ab73dc4e54d3 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
@@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
index dfd14c36c80f..65274bee2b19 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
@@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
new file mode 100644
index 000000000000..d9706d0709e5
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.TestConfigurations
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.node.ObjectNode
+
+import java.util.UUID
+
+/**
+ * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector.
+ *
+ * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows
+ * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that
+ * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations
+ * initiated via Spark DataFrames (reads and writes).
+ *
+ * Requires the Azure Cosmos DB Emulator to be running.
+ */
+class SparkE2EWorkloadIdITest
+ extends IntegrationSpec
+ with Spark
+ with CosmosClient
+ with AutoCleanableCosmosContainer
+ with BasicLoggingTrait {
+
+ val objectMapper = new ObjectMapper()
+
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off magic.number
+ //scalastyle:off null
+
+ // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The header should be passed through to the
+ // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior.
+ "spark query with customHeaders" can "read items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""",
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+
+ val item = rowsArray(0)
+ item.getAs[String]("id") shouldEqual id
+ }
+
+ // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The item is written via Spark and then verified
+ // via a direct SDK read to confirm the write was persisted correctly.
+ "spark write with customHeaders" can "write items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testWriteItem"
+ | }
+ |""".stripMargin
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""",
+ "spark.cosmos.write.strategy" -> "ItemOverwrite",
+ "spark.cosmos.write.bulk.enabled" -> "false",
+ "spark.cosmos.serialization.inclusionMode" -> "NonDefault"
+ )
+
+ val spark_session = spark
+ import spark_session.implicits._
+ val df = spark.read.json(Seq(rawItem).toDS())
+
+ df.write.format("cosmos.oltp").options(cfg).mode("Append").save()
+
+ // Verify the item was written by reading it back via the SDK directly
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block()
+ readItem.getItem.get("id").textValue() shouldEqual id
+ readItem.getItem.get("name").textValue() shouldEqual "testWriteItem"
+ }
+
+ // Regression test: verifies that Spark read operations continue to work correctly when
+ // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not
+ // break existing behavior for clients that do not use custom headers.
+ "spark operations without customHeaders" can "still succeed" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "noHeadersItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+ rowsArray(0).getAs[String]("id") shouldEqual id
+ }
+
+ //scalastyle:on magic.number
+ //scalastyle:on multiple.string.literals
+ //scalastyle:on null
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
index a3f77dee2dbe..799d86414ff6 100644
--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
new file mode 100644
index 000000000000..19eb03744d1a
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
@@ -0,0 +1,256 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.models.CosmosBatchRequestOptions;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.CosmosReadManyRequestOptions;
+import com.azure.cosmos.models.FeedRange;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
+ *
+ * These tests verify the public API surface: builder fluent methods, getter behavior,
+ * null/empty handling, and that setHeader() is publicly accessible on all request options classes.
+ */
+public class CustomHeadersTests {
+
+ /**
+ * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders()
+ * are stored correctly and retrievable via getCustomHeaders().
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersSetOnBuilder() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "25");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25");
+ }
+
+ /**
+ * Verifies that passing null to customHeaders() does not throw and that
+ * getCustomHeaders() returns null, ensuring graceful null handling.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersNullHandledGracefully() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(null);
+
+ assertThat(builder.getCustomHeaders()).isNull();
+ }
+
+ /**
+ * Verifies that passing an empty map to customHeaders() is accepted and
+ * getCustomHeaders() returns an empty (not null) map.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersEmptyMapHandled() {
+ Map<String, String> emptyHeaders = new HashMap<>();
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(emptyHeaders);
+
+ assertThat(builder.getCustomHeaders()).isEmpty();
+ }
+
+ /**
+ * Verifies that headers not in the allowlist are rejected with IllegalArgumentException.
+ * This ensures consistent behavior across Gateway and Direct modes — only headers with
+ * RNTBD encoding support are allowed.
+ */
+ @Test(groups = { "unit" })
+ public void unknownHeaderRejectedByAllowlist() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put("x-ms-custom-header", "value");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header")
+ .hasMessageContaining("not allowed");
+ }
+
+ /**
+ * Verifies that a map containing both an allowed header and a disallowed header
+ * is rejected — the entire map must pass the allowlist check.
+ */
+ @Test(groups = { "unit" })
+ public void mixedAllowedAndDisallowedHeadersRejected() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "15");
+ headers.put("x-ms-custom-header", "value");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header");
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on CRUD operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnItemRequestOptionsIsPublic() {
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "15");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on batch operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBatchRequestOptionsIsPublic() {
+ CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "20");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on change feed operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnChangeFeedRequestOptionsIsPublic() {
+ CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
+ .createForProcessingFromBeginning(FeedRange.forFullRange())
+ .setHeader("x-ms-cosmos-workload-id", "25");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on bulk ingestion operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBulkExecutionOptionsIsPublic() {
+ CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
+ .setHeader("x-ms-cosmos-workload-id", "30");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on query operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnQueryRequestOptionsIsPublic() {
+ CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "35");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on read-many operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnReadManyRequestOptionsIsPublic() {
+ CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "40");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
+ * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
+ * by the Cosmos DB service.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdHttpHeaderConstant() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at builder level with
+ * IllegalArgumentException. This covers both Gateway and Direct modes consistently
+ * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
+ * Range validation [1, 50] is the backend's responsibility — the SDK only validates
+ * that the value is a valid integer. This avoids hardcoding a range the backend team
+ * might change in the future.
+ */
+ @Test(groups = { "unit" })
+ public void outOfRangeWorkloadIdAcceptedByBuilder() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+ }
+
+ /**
+ * Verifies that a valid workload-id value passes builder validation.
+ */
+ @Test(groups = { "unit" })
+ public void validWorkloadIdAcceptedByBuilder() {
+ Map<String, String> headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
index a9f5cb35549c..d5f8b92ac7a6 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
@@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import static org.mockito.Mockito.doAnswer;
@@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy(
GlobalEndpointManager globalEndpointManager,
GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker,
HttpClient rxOrigClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map<String, String> customHeaders) {
origHttpClient = rxOrigClient;
spyHttpClient = Mockito.spy(rxOrigClient);
@@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy(
userAgentContainer,
globalEndpointManager,
spyHttpClient,
- apiType);
+ apiType,
+ customHeaders);
}
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
index 54440ecfabc5..587844f4043a 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
@@ -27,6 +27,8 @@
import java.net.SocketException;
import java.net.URI;
import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
@@ -102,6 +104,7 @@ public void readTimeout() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -205,7 +209,8 @@ public void applySessionToken(
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
@@ -277,7 +282,8 @@ public void validateApiType() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
@@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception {
return false;
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are injected into
+ * outgoing HTTP requests by performRequest(). This covers metadata requests
+ * (collection cache, partition key range) that don't go through getRequestHeaders().
+ */
+ @Test(groups = "unit")
+ public void customHeadersInjectedInPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor<HttpRequest> httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map<String, String> customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25");
+ }
+
+ /**
+ * Verifies that request-level headers take precedence over client-level customHeaders.
+ * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest()
+ * should NOT overwrite it.
+ */
+ @Test(groups = "unit")
+ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor<HttpRequest> httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map<String, String> customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col/docs/doc1",
+ ResourceType.Document);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ // Simulate request-level header already set (e.g., by getRequestHeaders())
+ dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // Request-level header "42" should win over client-level "10"
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, performRequest() still works normally
+ * without injecting any extra headers.
+ */
+ @Test(groups = "unit")
+ public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor<HttpRequest> httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ null);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // No workload-id header should be present
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull();
+ }
+
enum SessionTokenType {
NONE, // no session token applied
USER, // userControlled session token
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
index b06d6f89b8e9..775b74785630 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
@@ -25,6 +25,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient rxClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map<String, String> customHeaders) {
this.origRxGatewayStoreModel = super.createRxGatewayProxy(
sessionContainer,
consistencyLevel,
@@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
userAgentContainer,
globalEndpointManager,
rxClient,
- apiType);
+ apiType,
+ customHeaders);
this.requests = Collections.synchronizedList(new ArrayList<>());
this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel);
this.initRequestCapture();
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
index 570c385c6d17..71d78aefdb2d 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
@@ -57,6 +57,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
@@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds,
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti
null,
null,
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh()
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask);
@@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
// connected status
@@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs
return new HttpClientUnderTestWrapper(origHttpClient);
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are included in
+ * GatewayAddressCache's defaultRequestHeaders, which are sent on every address
+ * resolution request.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map<String, String> customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap<String, String> defaultRequestHeaders = (HashMap<String, String>) defaultRequestHeadersField.get(cache);
+
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.)
+ * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers
+ * set before customHeaders are preserved.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map<String, String> customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent");
+ customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version");
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap<String, String> defaultRequestHeaders = (HashMap<String, String>) defaultRequestHeadersField.get(cache);
+
+ // SDK headers should NOT be overwritten
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent");
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION);
+ // Custom header should still be added
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders
+ * contains only SDK system headers and no extra entries.
+ */
+ @Test(groups = { "unit" })
+ public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap<String, String> defaultRequestHeaders = (HashMap<String, String>) defaultRequestHeadersField.get(cache);
+
+ // Should only contain SDK system headers, no workload-id
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT);
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION);
+ assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID);
+ }
+
public String getNameBasedCollectionLink() {
return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId();
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
index 331be53cc7af..5879e7d3e61c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
@@ -110,7 +110,7 @@ public void resolveAsync() throws Exception {
GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider,
userAgentContainer,
- serviceConfigReader, connectionPolicy, null);
+ serviceConfigReader, connectionPolicy, null, null);
RxDocumentServiceRequest request;
request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(),
OperationType.Read,
@@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() {
userAgentContainer,
serviceConfigReader,
connectionPolicy,
+ null,
null);
GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache();
GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class);
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
new file mode 100644
index 000000000000..9ca123e16160
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.implementation.directconnectivity.rntbd;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import org.testng.annotations.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants.
+ *
+ * Test type: UNIT TEST — pure constant/enum checks; no emulator or network required.
+ * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC),
+ * correct token type (Byte), is not required, and is not in the thin-client ordered header list
+ * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()).
+ */
+public class RntbdWorkloadIdTests {
+
+    /**
+     * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders
+     * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and
+     * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdConstantExists() {
+        assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+    }
+
+    /**
+     * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader
+     * with the correct wire ID (0x00DC). This ID is used to identify the header in the
+     * binary RNTBD protocol when communicating in Direct mode.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderExists() {
+        // Verify WorkloadId enum value exists with correct ID
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader).isNotNull();
+        assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC);
+    }
+
+    /**
+     * Verifies that the WorkloadId RNTBD header is defined as Byte token type,
+     * consistent with the ThroughputBucket pattern. The workload ID value (1-50)
+     * is encoded as a single byte on the wire.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderIsByteType() {
+        // Verify WorkloadId is Byte type (same as ThroughputBucket pattern)
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte);
+    }
+
+    /**
+     * Verifies that WorkloadId is not a required RNTBD header. The header is optional —
+     * requests without a workload ID are valid and should not be rejected by the SDK.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderIsNotRequired() {
+        // WorkloadId should not be a required header
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader.isRequired()).isFalse();
+    }
+
+    /**
+     * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client
+     * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is
+     * excluded from this list and will be auto-encoded in the second pass of
+     * RntbdTokenStream.encode() along with other non-ordered headers.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdNotInThinClientOrderedList() {
+        // WorkloadId should NOT be in thinClientHeadersInOrderList
+        // It will be automatically encoded in the second pass of RntbdTokenStream.encode()
+        assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList)
+            .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId);
+    }
+
+    /**
+     * Verifies that valid workload ID values (1-50) can be parsed from String to int
+     * and cast to byte without data loss. Note: the SDK itself does not validate the
+     * range — this test confirms the encoding path works for expected values.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdValidValues() {
+        // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly
+        String[] validValues = {"1", "25", "50"};
+        for (String value : validValues) {
+            int parsed = Integer.parseInt(value);
+            byte byteVal = (byte) parsed;
+            assertThat(byteVal).isBetween((byte) 1, (byte) 50);
+        }
+    }
+
+    /**
+     * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause
+     * exceptions in the SDK's parsing path. The SDK intentionally does not validate
+     * the range — invalid values are accepted and sent to the service, which silently
+     * ignores them.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdInvalidValuesAcceptedBySdk() {
+        // SDK does NOT validate range — service silently ignores invalid values
+        // These should not throw exceptions in SDK
+        String[] invalidValues = {"0", "51", "-1", "100"};
+        for (String value : invalidValues) {
+            int parsed = Integer.parseInt(value);
+            byte byteVal = (byte) parsed;
+            // cast is lossless for any value in [-128, 127]; the prior isNotNull() on an autoboxed byte was vacuous
+            assertThat((int) byteVal).isEqualTo(parsed);
+        }
+    }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
new file mode 100644
index 000000000000..3bf2fdafce7c
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -0,0 +1,291 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.rx;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosAsyncDatabase;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.TestObject;
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.models.CosmosContainerProperties;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosItemResponse;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.PartitionKey;
+import com.azure.cosmos.models.PartitionKeyDefinition;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Factory;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * End-to-end integration tests for the custom headers / workload-id feature.
+ *
+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests
+ * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC),
+ * ensuring the workload-id header is correctly encoded and sent in both transport paths.
+ */
+public class WorkloadIdE2ETests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ @Factory(dataProvider = "simpleClientBuilderGatewaySession")
+ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
+ }
+
+    @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+    public void beforeClass() {
+        // Parameterized collections instead of raw types (raw Map/ArrayList caused unchecked warnings).
+        Map<String, String> headers = new HashMap<>();
+        headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+        // Build the client once with the workload-id header applied to every request.
+        clientWithWorkloadId = getClientBuilder()
+            .customHeaders(headers)
+            .buildAsyncClient();
+
+        database = createDatabase(clientWithWorkloadId, DATABASE_ID);
+
+        PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
+        ArrayList<String> paths = new ArrayList<>();
+        paths.add("/mypk");
+        partitionKeyDef.setPaths(paths);
+        CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef);
+        database.createContainer(containerProperties).block();
+        container = database.getContainer(CONTAINER_ID);
+    }
+
+    /**
+     * Verifies that a create (POST) operation succeeds when the client
+     * has a workload-id custom header set at the builder level. Confirms the header
+     * flows through the request pipeline without causing errors.
+     */
+    @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+    public void createItemWithClientLevelWorkloadId() {
+        TestObject doc = TestObject.create();
+        // Parameterized response type instead of the raw CosmosItemResponse.
+        CosmosItemResponse<TestObject> response = container
+            .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+            .block();
+
+        assertThat(response).isNotNull();
+        assertThat(response.getStatusCode()).isEqualTo(201);
+    }
+
+    /**
+     * Verifies that a read (GET) operation succeeds with the client-level workload-id
+     * header and that the correct document is returned. Ensures the header does not
+     * interfere with normal read semantics.
+     */
+    @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+    public void readItemWithClientLevelWorkloadId() {
+        // Verify read operation succeeds with workload-id header
+        TestObject doc = TestObject.create();
+        container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+        // Parameterized type is required: raw CosmosItemResponse erases getItem() to Object,
+        // which would make the getItem().getId() chain below fail to compile.
+        CosmosItemResponse<TestObject> response = container
+            .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class)
+            .block();
+
+        assertThat(response).isNotNull();
+        assertThat(response.getStatusCode()).isEqualTo(200);
+        assertThat(response.getItem().getId()).isEqualTo(doc.getId());
+    }
+
+    /**
+     * Verifies that a replace (PUT) operation succeeds with the client-level workload-id
+     * header. Confirms the header propagates correctly for update operations.
+     */
+    @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+    public void replaceItemWithClientLevelWorkloadId() {
+        // Verify replace operation succeeds with workload-id header
+        TestObject doc = TestObject.create();
+        container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+        // Mutate a property so the replace is a genuine update, then send it back.
+        doc.setStringProp("updated-" + UUID.randomUUID());
+        CosmosItemResponse<TestObject> response = container
+            .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+            .block();
+
+        assertThat(response).isNotNull();
+        assertThat(response.getStatusCode()).isEqualTo(200);
+    }
+
+ /**
+ * Verifies that a delete operation succeeds with the client-level workload-id header
+ * and returns the expected 204 No Content status code.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void deleteItemWithClientLevelWorkloadId() {
+ // Verify delete operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosItemResponse