diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index a66d7b9b7282..5c00937013ad 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 3edf40a6b496..705334e4c3a1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index be5982831836..97bff869bb59 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index cf62ab372904..d922e4a579d3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 3b87ef08c3a0..f31024628cb7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import java.util.function.BiPredicate import scala.collection.concurrent.TrieMap - // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import @@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + // Apply custom HTTP headers (e.g., workload-id) to the builder if configured. + // These headers are attached to every Cosmos DB request made by this client instance. + if (cosmosClientConfiguration.customHeaders.isDefined) { + builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava) + } + var client = builder.buildAsyncClient() if (cosmosClientConfiguration.clientInterceptors.isDefined) { @@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Custom HTTP headers are part of the cache key because different workload-ids + // should produce different CosmosAsyncClient instances + customHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientBuilderInterceptors, clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, - clientConfig.azureMonitorConfig + clientConfig.azureMonitorConfig, + clientConfig.customHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 6f4e26e1f503..61fa0957af83 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration ( clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Optional custom HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.customHeaders() + customHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientBuilderInterceptors, cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, - diagnosticsConfig.azureMonitorConfig + diagnosticsConfig.azureMonitorConfig, + cosmosAccountConfig.customHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..928b0cd09445 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper -import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} +import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils} import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition} import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter import java.time.{Duration, Instant} import java.util import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable @@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" + // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} + // Flows through to CosmosClientBuilder.customHeaders(). + val CustomHeaders = "spark.cosmos.customHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames { WriteOnRetryCommitInterceptor, WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, - WriteMaxRetryNoProgressIntervalInSeconds + WriteMaxRetryNoProgressIntervalInSeconds, + CustomHeaders ) def validateConfigName(name: String): Unit = { @@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String, resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], - clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + // Optional custom HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder + customHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -727,6 +736,24 @@ private object CosmosAccountConfig extends BasicLoggingTrait { parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like + // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. + // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache. + private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.CustomHeaders, + mandatory = false, + parseFromStringFunction = headersJson => { + try { + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + } catch { + case e: Exception => throw new IllegalArgumentException( + s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " + + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) + } + }, + helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -761,6 +788,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) + // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -880,7 +909,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { resourceGroupNameOpt, azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, - if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, + customHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index ccf36791dc96..4d542c44612e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -64,7 +64,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -91,7 +92,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -118,7 +120,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -145,7 +148,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ) ) @@ -179,8 +183,9 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None - ) + azureMonitorConfig = None, + customHeaders = None + ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 7fcc601ba016..377425189f07 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -408,4 +408,73 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ") configuration.azureMonitorConfig shouldEqual None } + + // Verifies that the spark.cosmos.customHeaders configuration option correctly parses + // a JSON string containing a single workload-id header into a Map[String, String] on + // CosmosClientConfiguration. This is the primary use case for the workload-id feature. + it should "parse customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + } + + // Verifies that when spark.cosmos.customHeaders is not specified in the config map, + // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set customHeaders continue to work without changes. + it should "handle missing customHeaders" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe None + } + + // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level. + // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD), + // unknown headers are silently dropped, so the allowlist ensures consistent behavior + // across Gateway and Direct modes. + it should "reject unknown custom headers" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is. + // The allowlist validation happens later in CosmosClientBuilder.customHeaders() + configuration.customHeaders shouldBe defined + configuration.customHeaders.get should have size 2 + } + + // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, + // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases + // and that no headers are injected when the JSON object is empty. + it should "handle empty customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> "{}" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get shouldBe empty + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 6ef90b55989d..ab73dc4e54d3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index dfd14c36c80f..65274bee2b19 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala new file mode 100644 index 000000000000..d9706d0709e5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.util.UUID + +/** + * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector. + * + * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows + * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that + * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations + * initiated via Spark DataFrames (reads and writes). + * + * Requires the Cosmos DB Emulator running + */ +class SparkE2EWorkloadIdITest + extends IntegrationSpec + with Spark + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait { + + val objectMapper = new ObjectMapper() + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The header should be passed through to the + // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior. + "spark query with customHeaders" can "read items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + + val item = rowsArray(0) + item.getAs[String]("id") shouldEqual id + } + + // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The item is written via Spark and then verified + // via a direct SDK read to confirm the write was persisted correctly. + "spark write with customHeaders" can "write items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testWriteItem" + | } + |""".stripMargin + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.write.strategy" -> "ItemOverwrite", + "spark.cosmos.write.bulk.enabled" -> "false", + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + val spark_session = spark + import spark_session.implicits._ + val df = spark.read.json(Seq(rawItem).toDS()) + + df.write.format("cosmos.oltp").options(cfg).mode("Append").save() + + // Verify the item was written by reading it back via the SDK directly + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block() + readItem.getItem.get("id").textValue() shouldEqual id + readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" + } + + // Regression test: verifies that Spark read operations continue to work correctly when + // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not + // break existing behavior for clients that do not use custom headers. + "spark operations without customHeaders" can "still succeed" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "noHeadersItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + rowsArray(0).getAs[String]("id") shouldEqual id + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals + //scalastyle:on null +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index a3f77dee2dbe..799d86414ff6 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java new file mode 100644 index 000000000000..19eb03744d1a --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. + *

+ * These tests verify the public API surface: builder fluent methods, getter behavior, + * null/empty handling, and that setHeader() is publicly accessible on all request options classes. + */ +public class CustomHeadersTests { + + /** + * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders() + * are stored correctly and retrievable via getCustomHeaders(). + */ + @Test(groups = { "unit" }) + public void customHeadersSetOnBuilder() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25"); + } + + /** + * Verifies that passing null to customHeaders() does not throw and that + * getCustomHeaders() returns null, ensuring graceful null handling. + */ + @Test(groups = { "unit" }) + public void customHeadersNullHandledGracefully() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(null); + + assertThat(builder.getCustomHeaders()).isNull(); + } + + /** + * Verifies that passing an empty map to customHeaders() is accepted and + * getCustomHeaders() returns an empty (not null) map. + */ + @Test(groups = { "unit" }) + public void customHeadersEmptyMapHandled() { + Map emptyHeaders = new HashMap<>(); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(emptyHeaders); + + assertThat(builder.getCustomHeaders()).isEmpty(); + } + + /** + * Verifies that headers not in the allowlist are rejected with IllegalArgumentException. + * This ensures consistent behavior across Gateway and Direct modes — only headers with + * RNTBD encoding support are allowed. + */ + @Test(groups = { "unit" }) + public void unknownHeaderRejectedByAllowlist() { + Map headers = new HashMap<>(); + headers.put("x-ms-custom-header", "value"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header") + .hasMessageContaining("not allowed"); + } + + /** + * Verifies that a map containing both an allowed header and a disallowed header + * is rejected — the entire map must pass the allowlist check. + */ + @Test(groups = { "unit" }) + public void mixedAllowedAndDisallowedHeadersRejected() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "15"); + headers.put("x-ms-custom-header", "value"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header"); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on CRUD operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnItemRequestOptionsIsPublic() { + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "15"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on batch operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBatchRequestOptionsIsPublic() { + CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "20"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on change feed operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnChangeFeedRequestOptionsIsPublic() { + CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setHeader("x-ms-cosmos-workload-id", "25"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on bulk ingestion operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBulkExecutionOptionsIsPublic() { + CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() + .setHeader("x-ms-cosmos-workload-id", "30"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on query operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnQueryRequestOptionsIsPublic() { + CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "35"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on read-many operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnReadManyRequestOptionsIsPublic() { + CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "40"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined + * with the correct canonical header name "x-ms-cosmos-workload-id" as expected + * by the Cosmos DB service. + */ + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that a non-numeric workload-id value is rejected at builder level with + * IllegalArgumentException. This covers both Gateway and Direct modes consistently + * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. + * Range validation [1, 50] is the backend's responsibility — the SDK only validates + * that the value is a valid integer. This avoids hardcoding a range the backend team + * might change in the future. + */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + } + + /** + * Verifies that a valid workload-id value passes builder validation. + */ + @Test(groups = { "unit" }) + public void validWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java index a9f5cb35549c..d5f8b92ac7a6 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.mockito.Mockito.doAnswer; @@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy( GlobalEndpointManager globalEndpointManager, GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker, HttpClient rxOrigClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { origHttpClient = rxOrigClient; spyHttpClient = Mockito.spy(rxOrigClient); @@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy( userAgentContainer, globalEndpointManager, spyHttpClient, - apiType); + apiType, + customHeaders); } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java index 54440ecfabc5..587844f4043a 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java @@ -27,6 +27,8 @@ import java.net.SocketException; import java.net.URI; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext; @@ -102,6 +104,7 @@ public void readTimeout() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -205,7 +209,8 @@ public void applySessionToken( new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( @@ -277,7 +282,8 @@ public void validateApiType() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception { return false; } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are injected into + * outgoing HTTP requests by performRequest(). This covers metadata requests + * (collection cache, partition key range) that don't go through getRequestHeaders(). + */ + @Test(groups = "unit") + public void customHeadersInjectedInPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25"); + } + + /** + * Verifies that request-level headers take precedence over client-level customHeaders. + * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest() + * should NOT overwrite it. + */ + @Test(groups = "unit") + public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col/docs/doc1", + ResourceType.Document); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + // Simulate request-level header already set (e.g., by getRequestHeaders()) + dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // Request-level header "42" should win over client-level "10" + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42"); + } + + /** + * Verifies that when customHeaders is null, performRequest() still works normally + * without injecting any extra headers. + */ + @Test(groups = "unit") + public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // No workload-id header should be present + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull(); + } + enum SessionTokenType { NONE, // no session token applied USER, // userControlled session token diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java index b06d6f89b8e9..775b74785630 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.stream.Collectors; @@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient rxClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.origRxGatewayStoreModel = super.createRxGatewayProxy( sessionContainer, consistencyLevel, @@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, rxClient, - apiType); + apiType, + customHeaders); this.requests = Collections.synchronizedList(new ArrayList<>()); this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel); this.initRequestCapture(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java index 570c385c6d17..71d78aefdb2d 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java @@ -57,6 +57,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds, null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti null, null, null, + null, null); RxDocumentServiceRequest req = @@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh() null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask); @@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); // connected status @@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs return new HttpClientUnderTestWrapper(origHttpClient); } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are included in + * GatewayAddressCache's defaultRequestHeaders, which are sent on every address + * resolution request. + */ + @Test(groups = { "unit" }) + public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) + * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers + * set before customHeaders are preserved. + */ + @Test(groups = { "unit" }) + public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); + customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // SDK headers should NOT be overwritten + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent"); + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION); + // Custom header should still be added + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders + * contains only SDK system headers and no extra entries. + */ + @Test(groups = { "unit" }) + public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + null); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // Should only contain SDK system headers, no workload-id + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT); + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION); + assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID); + } + public String getNameBasedCollectionLink() { return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId(); } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java index 331be53cc7af..5879e7d3e61c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java @@ -110,7 +110,7 @@ public void resolveAsync() throws Exception { GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider, userAgentContainer, - serviceConfigReader, connectionPolicy, null); + serviceConfigReader, connectionPolicy, null, null); RxDocumentServiceRequest request; request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(), OperationType.Read, @@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() { userAgentContainer, serviceConfigReader, connectionPolicy, + null, null); GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache(); GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java new file mode 100644 index 000000000000..9ca123e16160 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.implementation.directconnectivity.rntbd; + +import com.azure.cosmos.implementation.HttpConstants; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants. + *

+ * + * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC), + * correct token type (Byte), is not required, and is not in the thin-client ordered header list + * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()). + */ +public class RntbdWorkloadIdTests { + + /** + * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders + * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and + * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping. + */ + @Test(groups = { "unit" }) + public void workloadIdConstantExists() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader + * with the correct wire ID (0x00DC). This ID is used to identify the header in the + * binary RNTBD protocol when communicating in Direct mode. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderExists() { + // Verify WorkloadId enum value exists with correct ID + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader).isNotNull(); + assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC); + } + + /** + * Verifies that the WorkloadId RNTBD header is defined as Byte token type, + * consistent with the ThroughputBucket pattern. The workload ID value (1-50) + * is encoded as a single byte on the wire. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsByteType() { + // Verify WorkloadId is Byte type (same as ThroughputBucket pattern) + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte); + } + + /** + * Verifies that WorkloadId is not a required RNTBD header. The header is optional — + * requests without a workload ID are valid and should not be rejected by the SDK. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsNotRequired() { + // WorkloadId should not be a required header + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.isRequired()).isFalse(); + } + + /** + * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client + * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is + * excluded from this list and will be auto-encoded in the second pass of + * RntbdTokenStream.encode() along with other non-ordered headers. + */ + @Test(groups = { "unit" }) + public void workloadIdNotInThinClientOrderedList() { + // WorkloadId should NOT be in thinClientHeadersInOrderList + // It will be automatically encoded in the second pass of RntbdTokenStream.encode() + assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList) + .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId); + } + + /** + * Verifies that valid workload ID values (1-50) can be parsed from String to int + * and cast to byte without data loss. Note: the SDK itself does not validate the + * range — this test confirms the encoding path works for expected values. + */ + @Test(groups = { "unit" }) + public void workloadIdValidValues() { + // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly + String[] validValues = {"1", "25", "50"}; + for (String value : validValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + assertThat(byteVal).isBetween((byte) 1, (byte) 50); + } + } + + /** + * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause + * exceptions in the SDK's parsing path. The SDK intentionally does not validate + * the range — invalid values are accepted and sent to the service, which silently + * ignores them. + */ + @Test(groups = { "unit" }) + public void workloadIdInvalidValuesAcceptedBySdk() { + // SDK does NOT validate range — service silently ignores invalid values + // These should not throw exceptions in SDK + String[] invalidValues = {"0", "51", "-1", "100"}; + for (String value : invalidValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + // SDK accepts any integer value that fits in a byte + assertThat(byteVal).isNotNull(); + } + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java new file mode 100644 index 000000000000..3bf2fdafce7c --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosItemResponse; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * End-to-end integration tests for the custom headers / workload-id feature. + *

+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests + * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC), + * ensuring the workload-id header is correctly encoded and sent in both transport paths. + */ +public class WorkloadIdE2ETests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + @Factory(dataProvider = "simpleClientBuilderGatewaySession") + public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + clientWithWorkloadId = getClientBuilder() + .customHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * verifies that a create (POST) operation succeeds when the client + * has a workload-id custom header set at the builder level. Confirms the header + * flows through the request pipeline without causing errors. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithClientLevelWorkloadId() { + TestObject doc = TestObject.create(); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a read (GET) operation succeeds with the client-level workload-id + * header and that the correct document is returned. Ensures the header does not + * interfere with normal read semantics. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void readItemWithClientLevelWorkloadId() { + // Verify read operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getItem().getId()).isEqualTo(doc.getId()); + } + + /** + * Verifies that a replace (PUT) operation succeeds with the client-level workload-id + * header. Confirms the header propagates correctly for update operations. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void replaceItemWithClientLevelWorkloadId() { + // Verify replace operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + doc.setStringProp("updated-" + UUID.randomUUID()); + CosmosItemResponse response = container + .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + } + + /** + * Verifies that a delete operation succeeds with the client-level workload-id header + * and returns the expected 204 No Content status code. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void deleteItemWithClientLevelWorkloadId() { + // Verify delete operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(204); + } + + /** + * Verifies that a per-request workload-id header override via + * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header + * (value "30") should take precedence over the client-level default (value "15"). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override works — request-level should take precedence + TestObject doc = TestObject.create(); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30"); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), options) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a cross-partition query operation succeeds when the client has a + * workload-id custom header. Confirms the header flows correctly through the + * query pipeline and does not affect result correctness. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithClientLevelWorkloadId() { + // Verify query operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions(); + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Verifies that a per-request workload-id header override on + * {@code CosmosQueryRequestOptions.setHeader()} works for query operations. + * The request-level header (value "42") should take precedence over the + * client-level default. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override on query options works + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Regression test: verifies that a client created without any custom headers + * continues to work normally. Ensures the custom headers feature does not + * introduce regressions for clients that do not use it. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithNoCustomHeadersStillWorks() { + // Verify that a client without custom headers works normally (no regression) + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithoutHeaders); + } + } + + /** + * Verifies that a client created with an empty custom headers map works normally. + * An empty map should behave identically to no custom headers — no errors, + * no unexpected behavior. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithEmptyCustomHeaders() { + // Verify that a client with empty custom headers map works normally + CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder()) + .customHeaders(new HashMap<>()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithEmptyHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithEmptyHeaders); + } + } + + /** + * Verifies that unknown headers in customHeaders are rejected by the allowlist. + * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist + * ensures consistent behavior across Gateway and Direct modes. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void unknownCustomHeadersRejectedByAllowlist() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); + headers.put("x-ms-custom-test-header", "test-value"); + + // Should throw IllegalArgumentException due to unknown header + copyCosmosClientBuilder(getClientBuilder()) + .customHeaders(headers); + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index cd5510ae024c..c080bbc9c5c5 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -4,6 +4,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java index ec0dd64af008..f54f44482db5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java @@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable { .withDefaultSerializer(this.defaultCustomSerializer) .withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled()) .withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled()) + .withCustomHeaders(builder.getCustomHeaders()) .build(); this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 12d022e69ee7..e4cdf6ca1e30 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -13,6 +13,7 @@ import com.azure.cosmos.implementation.ConnectionPolicy; import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot; import com.azure.cosmos.implementation.DiagnosticsProvider; +import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.implementation.Strings; import com.azure.cosmos.implementation.WriteRetryPolicy; import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList; @@ -37,6 +38,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; @@ -155,6 +157,21 @@ public class CosmosClientBuilder implements private boolean serverCertValidationDisabled = false; private Function containerFactory = null; + private Map customHeaders; + + /** + * Allowlist of headers permitted in {@link #customHeaders(Map)}. + *

+ * In Direct mode (RNTBD), only headers with explicit encoding support in + * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped. + * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header + * is allowed here, it works in both modes. To add a new allowed header, you must also add + * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry + + * {@code RntbdRequestHeaders.addXxx()} method). + */ + private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet( + new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID)) + ); /** * Instantiates a new Cosmos client builder. @@ -734,6 +751,62 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { return this; } + /** + * Sets custom HTTP headers that will be included with every request from this client. + *

+ * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is + * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw + * {@link IllegalArgumentException}. + *

+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit + * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist + * ensures consistent behavior across both Gateway and Direct modes. + *

+ * If the same header is also set on request options (e.g., + * {@code CosmosItemRequestOptions.setHeader(String, String)}), + * the request-level value takes precedence over the client-level value. + * + * @param customHeaders map of header name to value + * @return current CosmosClientBuilder + * @throws IllegalArgumentException if any header key is not in the allowlist, or if the + * workload-id value is not a valid integer + */ + public CosmosClientBuilder customHeaders(Map customHeaders) { + if (customHeaders != null) { + for (Map.Entry entry : customHeaders.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + + if (!ALLOWED_CUSTOM_HEADERS.contains(key)) { + throw new IllegalArgumentException( + "Header '" + key + "' is not allowed in customHeaders. " + + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS); + } + + // Validate workload-id value is a valid integer (range validation is left to the backend) + if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) { + try { + Integer.parseInt(value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid value '" + value + "' for header '" + key + + "'. The value must be a valid integer.", e); + } + } + } + } + this.customHeaders = customHeaders; + return this; + } + + /** + * Gets the custom headers configured on this builder. + * @return the custom headers map, or null if not set + */ + Map getCustomHeaders() { + return this.customHeaders; + } + /** * Sets the retry policy options associated with the DocumentClient instance. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 03590c1f8a5d..7953721019c5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -116,6 +116,7 @@ class Builder { private boolean isRegionScopedSessionCapturingEnabled; private boolean isPerPartitionAutomaticFailoverEnabled; private List operationPolicies; + private Map customHeaders; public Builder withServiceEndpoint(String serviceEndpoint) { try { @@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu return this; } + public Builder withCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + private void ifThrowIllegalArgException(boolean value, String error) { if (value) { throw new IllegalArgumentException(error); @@ -328,7 +334,8 @@ public AsyncDocumentClient build() { defaultCustomSerializer, isRegionScopedSessionCapturingEnabled, operationPolicies, - isPerPartitionAutomaticFailoverEnabled); + isPerPartitionAutomaticFailoverEnabled, + customHeaders); client.init(state, null); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java index 4e283defbc1d..32378ef0cc8d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java @@ -298,6 +298,9 @@ public static class HttpHeaders { // Region affinity headers public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only"; + + // Workload ID header for Azure Monitor metrics attribution + public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id"; } public static class A_IMHeaderValues { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 725e22c3a253..c28400e86973 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -294,6 +294,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization private final AtomicReference cachedCosmosAsyncClientSnapshot; private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads; private Consumer perPartitionFailoverConfigModifier; + private Map customHeaders; public RxDocumentClientImpl(URI serviceEndpoint, String masterKeyOrResourceToken, @@ -367,6 +368,60 @@ public RxDocumentClientImpl(URI serviceEndpoint, boolean isRegionScopedSessionCapturingEnabled, List operationPolicies, boolean isPerPartitionAutomaticFailoverEnabled) { + this( + serviceEndpoint, + masterKeyOrResourceToken, + permissionFeed, + connectionPolicy, + consistencyLevel, + readConsistencyStrategy, + configs, + cosmosAuthorizationTokenResolver, + credential, + tokenCredential, + sessionCapturingOverride, + connectionSharingAcrossClientsEnabled, + contentResponseOnWriteEnabled, + metadataCachesSnapshot, + apiType, + clientTelemetryConfig, + clientCorrelationId, + cosmosEndToEndOperationLatencyPolicyConfig, + sessionRetryOptions, + containerProactiveInitConfig, + defaultCustomSerializer, + isRegionScopedSessionCapturingEnabled, + operationPolicies, + isPerPartitionAutomaticFailoverEnabled, + null + ); + } + + public RxDocumentClientImpl(URI serviceEndpoint, + String masterKeyOrResourceToken, + List permissionFeed, + ConnectionPolicy connectionPolicy, + ConsistencyLevel consistencyLevel, + ReadConsistencyStrategy readConsistencyStrategy, + Configs configs, + CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver, + AzureKeyCredential credential, + TokenCredential tokenCredential, + boolean sessionCapturingOverride, + boolean connectionSharingAcrossClientsEnabled, + boolean contentResponseOnWriteEnabled, + CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, + ApiType apiType, + CosmosClientTelemetryConfig clientTelemetryConfig, + String clientCorrelationId, + CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig, + SessionRetryOptions sessionRetryOptions, + CosmosContainerProactiveInitConfig containerProactiveInitConfig, + CosmosItemSerializer defaultCustomSerializer, + boolean isRegionScopedSessionCapturingEnabled, + List operationPolicies, + boolean isPerPartitionAutomaticFailoverEnabled, + Map customHeaders) { this( serviceEndpoint, masterKeyOrResourceToken, @@ -393,6 +448,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver; this.operationPolicies = operationPolicies; + this.customHeaders = customHeaders; } private RxDocumentClientImpl(URI serviceEndpoint, @@ -808,7 +864,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.userAgentContainer, this.globalEndpointManager, this.reactorHttpClient, - this.apiType); + this.apiType, + this.customHeaders); this.thinProxy = createThinProxy(this.sessionContainer, this.consistencyLevel, @@ -925,7 +982,8 @@ private void initializeDirectConnectivity() { // this.gatewayConfigurationReader, null, this.connectionPolicy, - this.apiType); + this.apiType, + this.customHeaders); this.storeClientFactory = new StoreClientFactory( this.addressResolver, @@ -969,7 +1027,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { return new RxGatewayStoreModel( this, sessionContainer, @@ -978,7 +1037,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, httpClient, - apiType); + apiType, + customHeaders); } ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, @@ -1896,6 +1956,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) { Map headers = new HashMap<>(); + // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders()) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + headers.putAll(this.customHeaders); + } + if (this.useMultipleWriteLocations) { headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString()); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java index 1d5a5ea260dc..35f6c64c0079 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java @@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize private GatewayServiceConfigurationReader gatewayServiceConfigurationReader; private RxClientCollectionCache collectionCache; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public RxGatewayStoreModel( DiagnosticsClientContext clientContext, @@ -100,7 +101,8 @@ public RxGatewayStoreModel( UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.clientContext = clientContext; @@ -116,6 +118,7 @@ public RxGatewayStoreModel( this.httpClient = httpClient; this.sessionContainer = sessionContainer; + this.customHeaders = customHeaders; } public RxGatewayStoreModel(RxGatewayStoreModel inner) { @@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) { this.httpClient = inner.httpClient; this.sessionContainer = inner.sessionContainer; + this.customHeaders = inner.customHeaders; } protected Map getDefaultHeaders( @@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics(); } + // Apply client-level custom headers (e.g., workload-id) to all requests + // including metadata requests (collection cache, partition key range, etc.) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + for (Map.Entry entry : this.customHeaders.entrySet()) { + // Only set if not already present — request-level headers take precedence + if (!request.getHeaders().containsKey(entry.getKey())) { + request.getHeaders().put(entry.getKey(), entry.getValue()); + } + } + } + URI uri = getUri(request); request.requestContext.resourcePhysicalAddress = uri.toString(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java index d32e5d901f18..ff139e203d2e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java @@ -56,7 +56,8 @@ public ThinClientStoreModel( userAgentContainer, globalEndpointManager, httpClient, - ApiType.SQL); + ApiType.SQL, + null); String userAgent = userAgentContainer != null ? userAgentContainer.getUserAgent() diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index e62d7b8c6ca4..7c761335b782 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -123,7 +123,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this.clientContext = clientContext; try { @@ -165,6 +166,14 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); + // Apply client-level custom headers (e.g., workload-id) to metadata requests + // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten + if (customHeaders != null && !customHeaders.isEmpty()) { + for (Map.Entry entry : customHeaders.entrySet()) { + this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + this.lastForcedRefreshMap = new ConcurrentHashMap<>(); this.globalEndpointManager = globalEndpointManager; this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor; @@ -188,7 +197,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this(clientContext, serviceEndpoint, protocol, @@ -200,7 +210,8 @@ public GatewayAddressCache( globalEndpointManager, connectionPolicy, proactiveOpenConnectionsProcessor, - gatewayServerErrorInjector); + gatewayServerErrorInjector, + customHeaders); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java index 00905682b4d1..2fd5287da028 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java @@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver { private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor; private ConnectionPolicy connectionPolicy; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public GlobalAddressResolver( DiagnosticsClientContext diagnosticsClientContext, @@ -74,7 +75,8 @@ public GlobalAddressResolver( UserAgentContainer userAgentContainer, GatewayServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.diagnosticsClientContext = diagnosticsClientContext; this.httpClient = httpClient; this.endpointManager = endpointManager; @@ -86,6 +88,7 @@ public GlobalAddressResolver( this.serviceConfigReader = serviceConfigReader; this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled(); this.connectionPolicy = connectionPolicy; + this.customHeaders = customHeaders; int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0; this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover) @@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) { this.endpointManager, this.connectionPolicy, this.proactiveOpenConnectionsProcessor, - this.gatewayServerErrorInjector); + this.gatewayServerErrorInjector, + this.customHeaders); AddressResolver addressResolver = new AddressResolver(); addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache); EndpointCache cache = new EndpointCache(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index ba3ec8d2017d..d75bf5dc88e1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -598,6 +598,7 @@ public enum RntbdRequestHeader implements RntbdHeader { PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false), HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); public static final List thinClientHeadersInOrderList = Arrays.asList( diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index 6f6e46ee695d..46f8060387fc 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonFilter; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { // region Fields + private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class); private static final String URL_TRIM = "/"; // endregion @@ -134,6 +137,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { this.addGlobalDatabaseAccountName(headers); this.addThroughputBucket(headers); this.addHubRegionProcessingOnly(headers); + this.addWorkloadId(headers); // Normal headers (Strings, Ints, Longs, etc.) @@ -297,6 +301,8 @@ private RntbdToken getCorrelatedActivityId() { private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); } + private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); } + private RntbdToken getGlobalDatabaseAccountName() { return this.get(RntbdRequestHeader.GlobalDatabaseAccountName); } @@ -816,6 +822,19 @@ private void addHubRegionProcessingOnly(final Map headers) { } } + private void addWorkloadId(final Map headers) { + final String value = headers.get(HttpHeaders.WORKLOAD_ID); + + if (StringUtils.isNotEmpty(value)) { + try { + final int workloadId = Integer.parseInt(value); + this.getWorkloadId().setValue((byte) workloadId); + } catch (NumberFormatException e) { + logger.warn("Invalid value for workload id header: {}", value, e); + } + } + } + private void addGlobalDatabaseAccountName(final Map headers) { final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index 7d5a27324f95..3183fe59bdea 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -154,14 +154,17 @@ RequestOptions toRequestOptions() { } /** - * Sets the custom batch request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBatchRequestOptions. */ - CosmosBatchRequestOptions setHeader(String name, String value) { + public CosmosBatchRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index f125c02d6725..cd688f8a0da6 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -257,13 +257,17 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat } /** - * Sets the custom bulk request option value by key + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBulkExecutionOptions. */ - CosmosBulkExecutionOptions setHeader(String name, String value) { + public CosmosBulkExecutionOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index 3ac526de6d63..a1b675f2ffd8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -564,14 +564,17 @@ public List getExcludedRegions() { } /** - * Sets the custom change feed request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosChangeFeedRequestOptions. */ - CosmosChangeFeedRequestOptions setHeader(String name, String value) { + public CosmosChangeFeedRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index 72eb108a6428..fbc540e5baeb 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -566,14 +566,17 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre } /** - * Sets the custom item request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value - * + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosItemRequestOptions. */ - CosmosItemRequestOptions setHeader(String name, String value) { + public CosmosItemRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index 7ead6e208781..f0de81bbf823 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -260,6 +260,22 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) return this; } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosQueryRequestOptions. + */ + public CosmosQueryRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + /** * Gets the list of regions to exclude for the request/retries. These regions are excluded * from the preferred region list. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index f6e570258042..de2d769f789b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -366,6 +366,22 @@ public Set getKeywordIdentifiers() { return this.actualRequestOptions.getKeywordIdentifiers(); } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosReadManyRequestOptions. + */ + public CosmosReadManyRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + CosmosQueryRequestOptionsBase getImpl() { return this.actualRequestOptions; }