Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)

#### Features Added
* Added `customHeaders` support to allow setting custom request headers (e.g., `x-ms-cosmos-workload-id`, which applies to both Gateway/HTTP and Direct/RNTBD transports) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: x-ms-cosmos-workload-id is an RNTBD header too.


#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)

#### Features Added
* Added `customHeaders` support to allow setting custom request headers (e.g., `x-ms-cosmos-workload-id`, which applies to both Gateway/HTTP and Direct/RNTBD transports) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)

#### Features Added
* Added `customHeaders` support to allow setting custom request headers (e.g., `x-ms-cosmos-workload-id`, which applies to both Gateway/HTTP and Direct/RNTBD transports) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.45.0-beta.1 (Unreleased)

#### Features Added
* Added `customHeaders` support to allow setting custom request headers (e.g., `x-ms-cosmos-workload-id`, which applies to both Gateway/HTTP and Direct/RNTBD transports) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.function.BiPredicate
import scala.collection.concurrent.TrieMap

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
Expand Down Expand Up @@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}

// Apply custom HTTP headers (e.g., workload-id) to the builder if configured.
// These headers are attached to every Cosmos DB request made by this client instance.
if (cosmosClientConfiguration.customHeaders.isDefined) {
builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava)
}

var client = builder.buildAsyncClient()

if (cosmosClientConfiguration.clientInterceptors.isDefined) {
Expand Down Expand Up @@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig]
azureMonitorConfig: Option[AzureMonitorConfig],
// Custom HTTP headers are part of the cache key because different workload-ids
// should produce different CosmosAsyncClient instances
customHeaders: Option[Map[String, String]]
)

private[this] object ClientConfigurationWrapper {
Expand All @@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientBuilderInterceptors,
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
clientConfig.azureMonitorConfig
clientConfig.azureMonitorConfig,
clientConfig.customHeaders
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration (
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig]
azureMonitorConfig: Option[AzureMonitorConfig],
// Optional custom HTTP headers (e.g., workload-id) to attach to
// all Cosmos DB requests via CosmosClientBuilder.customHeaders()
customHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
Expand Down Expand Up @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientBuilderInterceptors,
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
diagnosticsConfig.azureMonitorConfig
diagnosticsConfig.azureMonitorConfig,
cosmosAccountConfig.customHeaders
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark}
import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants
import com.azure.cosmos.implementation.routing.LocationHelper
import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings}
import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils}
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
Expand All @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util
import java.util.{Locale, ServiceLoader}
import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.{HashSet, List, Map}
import scala.collection.mutable
Expand Down Expand Up @@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
// Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
// Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
// Flows through to CosmosClientBuilder.customHeaders().
val CustomHeaders = "spark.cosmos.customHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
Expand Down Expand Up @@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames {
WriteOnRetryCommitInterceptor,
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
WriteMaxRetryNoProgressIntervalInSeconds
WriteMaxRetryNoProgressIntervalInSeconds,
CustomHeaders
)

def validateConfigName(name: String): Unit = {
Expand Down Expand Up @@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String,
resourceGroupName: Option[String],
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
// Optional custom HTTP headers (e.g., workload-id) parsed from
// spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder
customHeaders: Option[Map[String, String]]
)

private object CosmosAccountConfig extends BasicLoggingTrait {
Expand Down Expand Up @@ -727,6 +736,24 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN,
helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.")

// Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
// {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
// These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache.
private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]](
key = CosmosConfigNames.CustomHeaders,
mandatory = false,
parseFromStringFunction = headersJson => {
try {
val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap
} catch {
case e: Exception => throw new IllegalArgumentException(
s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " +
"Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
}
},
helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")

private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
try {
Expand Down Expand Up @@ -761,6 +788,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
// Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig)

val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
Expand Down Expand Up @@ -880,7 +909,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
resourceGroupNameOpt,
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None })
if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
customHeaders)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)
),
(
Expand All @@ -91,7 +92,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)
),
(
Expand All @@ -118,7 +120,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)
),
(
Expand All @@ -145,7 +148,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)
)
)
Expand Down Expand Up @@ -179,8 +183,9 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
)
azureMonitorConfig = None,
customHeaders = None
)

logInfo(s"TestCase: {$testCaseName}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,73 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ")
configuration.azureMonitorConfig shouldEqual None
}

// Verifies that the spark.cosmos.customHeaders configuration option correctly parses
// a JSON string containing a single workload-id header into a Map[String, String] on
// CosmosClientConfiguration. This is the primary use case for the workload-id feature.
it should "parse customHeaders JSON" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
"spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
)

val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")

configuration.customHeaders shouldBe defined
configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
}

// Verifies that when spark.cosmos.customHeaders is not specified in the config map,
// CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility —
// existing Spark jobs that don't set customHeaders continue to work without changes.
it should "handle missing customHeaders" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz"
)

val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")

configuration.customHeaders shouldBe None
}

// Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level.
// Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD),
// unknown headers are silently dropped, so the allowlist ensures consistent behavior
// across Gateway and Direct modes.
it should "reject unknown custom headers" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
"spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
)

val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")

// Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is.
// The allowlist validation happens later in CosmosClientBuilder.customHeaders()
configuration.customHeaders shouldBe defined
configuration.customHeaders.get should have size 2
}

// Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
// resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases
// and that no headers are injected when the JSON object is empty.
it should "handle empty customHeaders JSON" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
"spark.cosmos.customHeaders" -> "{}"
)

val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")

configuration.customHeaders shouldBe defined
configuration.customHeaders.get shouldBe empty
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down Expand Up @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
customHeaders = None
)

val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
Expand Down
Loading
Loading