Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/Core/Resolvers/MsSqlQueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Net;
using System.Security.Claims;
using System.Text;
Expand Down Expand Up @@ -520,6 +521,49 @@ public override string GetSessionParamsQuery(HttpContext? httpContext, IDictiona
sessionMapQuery = sessionMapQuery.Append(statementToSetReadOnlyParam);
}

// Add OpenTelemetry correlation values for observability.
// These allow correlating database queries with distributed traces.
Activity? currentActivity = Activity.Current;
if (currentActivity is not null)
{
string traceIdParamName = $"{SESSION_PARAM_NAME}{counter.Next()}";
parameters.Add(traceIdParamName, new(currentActivity.TraceId.ToString()));
sessionMapQuery.Append($"EXEC sp_set_session_context 'dab.trace_id', {traceIdParamName}, @read_only = 0;");

string spanIdParamName = $"{SESSION_PARAM_NAME}{counter.Next()}";
parameters.Add(spanIdParamName, new(currentActivity.SpanId.ToString()));
sessionMapQuery.Append($"EXEC sp_set_session_context 'dab.span_id', {spanIdParamName}, @read_only = 0;");
}

// Add OBO-specific observability values when user-delegated auth is enabled.
// These values are for observability/auditing only and MUST NOT be used for authorization decisions.
// For row-level security, use database policies with the user's actual claims.
if (_dataSourceUserDelegatedAuth.ContainsKey(dataSourceName))
{
// Set auth type indicator for OBO requests
string authTypeParamName = $"{SESSION_PARAM_NAME}{counter.Next()}";
parameters.Add(authTypeParamName, new("obo"));
sessionMapQuery.Append($"EXEC sp_set_session_context 'dab.auth_type', {authTypeParamName}, @read_only = 0;");

// Set user identifier (oid preferred, fallback to sub) for auditing/observability
string? userId = httpContext.User.FindFirst("oid")?.Value ?? httpContext.User.FindFirst("sub")?.Value;
if (!string.IsNullOrWhiteSpace(userId))
{
string userIdParamName = $"{SESSION_PARAM_NAME}{counter.Next()}";
parameters.Add(userIdParamName, new(userId));
sessionMapQuery.Append($"EXEC sp_set_session_context 'dab.user_id', {userIdParamName}, @read_only = 0;");
}

// Set tenant identifier for auditing/observability
string? tenantId = httpContext.User.FindFirst("tid")?.Value;
if (!string.IsNullOrWhiteSpace(tenantId))
{
string tenantIdParamName = $"{SESSION_PARAM_NAME}{counter.Next()}";
parameters.Add(tenantIdParamName, new(tenantId));
sessionMapQuery.Append($"EXEC sp_set_session_context 'dab.tenant_id', {tenantIdParamName}, @read_only = 0;");
}
}

return sessionMapQuery.ToString();
}

Expand Down
41 changes: 32 additions & 9 deletions src/Core/Resolvers/OboSqlTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net;
using System.Security.Claims;
using System.Security.Cryptography;
Expand Down Expand Up @@ -66,21 +67,30 @@ public OboSqlTokenProvider(
{
if (principal is null)
{
_logger.LogWarning("Cannot acquire OBO token: ClaimsPrincipal is null.");
_logger.LogWarning(
"{EventType}: Cannot acquire OBO token - ClaimsPrincipal is null (traceId: {TraceId}).",
"OboValidationFailed",
Activity.Current?.TraceId.ToString() ?? "none");
return null;
}

if (string.IsNullOrWhiteSpace(incomingJwtAssertion))
{
_logger.LogWarning("Cannot acquire OBO token: Incoming JWT assertion is null or empty.");
_logger.LogWarning(
"{EventType}: Cannot acquire OBO token - Incoming JWT assertion is null or empty (traceId: {TraceId}).",
"OboValidationFailed",
Activity.Current?.TraceId.ToString() ?? "none");
return null;
}

// Extract identity claims
string? subjectId = ExtractSubjectId(principal);
if (string.IsNullOrWhiteSpace(subjectId))
{
_logger.LogWarning("Cannot acquire OBO token: Neither 'oid' nor 'sub' claim found in token.");
_logger.LogWarning(
"{EventType}: Cannot acquire OBO token - Neither 'oid' nor 'sub' claim found in token (traceId: {TraceId}).",
"OboValidationFailed",
Activity.Current?.TraceId.ToString() ?? "none");
throw new DataApiBuilderException(
message: DataApiBuilderException.OBO_IDENTITY_CLAIMS_MISSING,
statusCode: HttpStatusCode.Unauthorized,
Expand All @@ -90,7 +100,10 @@ public OboSqlTokenProvider(
string? tenantId = principal.FindFirst("tid")?.Value;
if (string.IsNullOrWhiteSpace(tenantId))
{
_logger.LogWarning("Cannot acquire OBO token: 'tid' (tenant id) claim not found or empty in token.");
_logger.LogWarning(
"{EventType}: Cannot acquire OBO token - 'tid' (tenant id) claim not found or empty in token (traceId: {TraceId}).",
"OboValidationFailed",
Activity.Current?.TraceId.ToString() ?? "none");
throw new DataApiBuilderException(
message: DataApiBuilderException.OBO_TENANT_CLAIM_MISSING,
statusCode: HttpStatusCode.Unauthorized,
Expand All @@ -115,9 +128,11 @@ public OboSqlTokenProvider(
{
wasCacheMiss = true;
_logger.LogInformation(
"OBO token cache MISS for subject {SubjectId} (tenant: {TenantId}). Acquiring new token from Azure AD.",
"{EventType}: OBO token cache MISS for subject {SubjectId} (tenant: {TenantId}, traceId: {TraceId}). Acquiring new token from Azure AD.",
"OboTokenCacheMiss",
subjectId,
tenantId);
tenantId,
Activity.Current?.TraceId.ToString() ?? "none");

AuthenticationResult result = await _msalClient.AcquireTokenOnBehalfOfAsync(
scopes,
Expand All @@ -144,8 +159,10 @@ public OboSqlTokenProvider(
ctx.Options.SetSkipDistributedCache(true, true);

_logger.LogInformation(
"OBO token ACQUIRED for subject {SubjectId}. Expires: {ExpiresOn}, Cache TTL: {CacheDuration}.",
"{EventType}: OBO token ACQUIRED for subject {SubjectId} (traceId: {TraceId}). Expires: {ExpiresOn}, Cache TTL: {CacheDuration}.",
"OboTokenAcquired",
subjectId,
Activity.Current?.TraceId.ToString() ?? "none",
result.ExpiresOn,
cacheDuration);

Expand All @@ -155,7 +172,11 @@ public OboSqlTokenProvider(

if (!string.IsNullOrEmpty(accessToken) && !wasCacheMiss)
{
_logger.LogInformation("OBO token cache HIT for subject {SubjectId}.", subjectId);
_logger.LogInformation(
"{EventType}: OBO token cache HIT for subject {SubjectId} (traceId: {TraceId}).",
"OboTokenCacheHit",
subjectId,
Activity.Current?.TraceId.ToString() ?? "none");
}

return accessToken;
Expand All @@ -164,8 +185,10 @@ public OboSqlTokenProvider(
{
_logger.LogError(
ex,
"Failed to acquire OBO token for subject {SubjectId}. Error: {ErrorCode} - {Message}",
"{EventType}: Failed to acquire OBO token for subject {SubjectId} (traceId: {TraceId}). Error: {ErrorCode} - {Message}",
"OboTokenAcquisitionFailed",
subjectId,
Activity.Current?.TraceId.ToString() ?? "none",
ex.ErrorCode,
ex.Message);
throw new DataApiBuilderException(
Expand Down
200 changes: 200 additions & 0 deletions src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,206 @@ private static Mock<IHttpContextAccessor> CreateHttpContextAccessorWithAuthentic

#endregion

/// <summary>
/// Validates that GetSessionParamsQuery includes all observability values:
/// - OpenTelemetry correlation values (dab.trace_id, dab.span_id) when an Activity is present
/// - OBO observability values (dab.auth_type, dab.user_id, dab.tenant_id) when user-delegated auth is enabled
/// </summary>
[TestMethod, TestCategory(TestCategory.MSSQL)]
public void GetSessionParamsQuery_IncludesAllObservabilityValues_WhenActivityAndOboEnabled()
{
// Arrange
TestHelper.SetupDatabaseEnvironment(TestCategory.MSSQL);

// Create runtime config with user-delegated-auth enabled and set-session-context
RuntimeConfig runtimeConfig = new(
Schema: "UnitTestSchema",
DataSource: new DataSource(
DatabaseType: DatabaseType.MSSQL,
ConnectionString: "Server=localhost;Database=TestDb;",
Options: new Dictionary<string, object> { { "set-session-context", true } })
{
UserDelegatedAuth = new UserDelegatedAuthOptions(
Enabled: true,
Provider: "EntraId",
DatabaseAudience: "https://database.windows.net/")
},
Runtime: new(
Rest: new(),
GraphQL: new(),
Mcp: new(),
Host: new(Cors: null, Authentication: null)),
Entities: new(new Dictionary<string, Entity>()));

MockFileSystem fileSystem = new();
fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(runtimeConfig.ToJson()));
FileSystemRuntimeConfigLoader loader = new(fileSystem);
RuntimeConfigProvider runtimeConfigProvider = new(loader);

Mock<ILogger<QueryExecutor<SqlConnection>>> queryExecutorLogger = new();
Mock<IHttpContextAccessor> httpContextAccessor = new();
DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(runtimeConfigProvider);

MsSqlQueryExecutor msSqlQueryExecutor = new(
runtimeConfigProvider,
dbExceptionParser,
queryExecutorLogger.Object,
httpContextAccessor.Object);

// Create a mock HttpContext with OBO-specific claims (oid, tid, sub)
Mock<HttpContext> mockContext = new();
Mock<HttpRequest> mockRequest = new();
Mock<IHeaderDictionary> mockHeaders = new();

mockHeaders.Setup(h => h["Authorization"]).Returns("Bearer test-token");
mockRequest.Setup(r => r.Headers).Returns(mockHeaders.Object);
mockContext.Setup(c => c.Request).Returns(mockRequest.Object);

var identity = new System.Security.Claims.ClaimsIdentity(
new[]
{
new System.Security.Claims.Claim("oid", "00000000-0000-0000-0000-000000000001"),
new System.Security.Claims.Claim("tid", "11111111-1111-1111-1111-111111111111"),
new System.Security.Claims.Claim("sub", "test-subject")
},
"TestAuth");
var principal = new System.Security.Claims.ClaimsPrincipal(identity);
mockContext.Setup(c => c.User).Returns(principal);

Dictionary<string, DbConnectionParam> parameters = new();

// Act - Create an Activity to simulate OpenTelemetry tracing
using ActivitySource activitySource = new("TestActivitySource");
using ActivityListener listener = new()
{
ShouldListenTo = _ => true,
Sample = (ref ActivityCreationOptions<ActivityContext> _) => ActivitySamplingResult.AllData
};
ActivitySource.AddActivityListener(listener);

using Activity testActivity = activitySource.StartActivity("TestOperation")!;
Assert.IsNotNull(testActivity, "Activity should be created for test");

string sessionParamsQuery = msSqlQueryExecutor.GetSessionParamsQuery(
mockContext.Object,
parameters,
runtimeConfigProvider.GetConfig().DefaultDataSourceName);

// Assert
Assert.IsFalse(string.IsNullOrEmpty(sessionParamsQuery), "Session params query should not be empty");

// Verify OpenTelemetry correlation values are included
Assert.IsTrue(
sessionParamsQuery.Contains("'dab.trace_id'"),
"Session params query should include dab.trace_id");
Assert.IsTrue(
sessionParamsQuery.Contains("'dab.span_id'"),
"Session params query should include dab.span_id");

// Verify the correlation values are in the parameters
Assert.IsTrue(
parameters.Values.Any(p => p.Value?.ToString() == testActivity.TraceId.ToString()),
$"Parameters should contain trace_id value: {testActivity.TraceId}");
Assert.IsTrue(
parameters.Values.Any(p => p.Value?.ToString() == testActivity.SpanId.ToString()),
$"Parameters should contain span_id value: {testActivity.SpanId}");

// Verify OBO-specific observability values are included
Assert.IsTrue(
sessionParamsQuery.Contains("'dab.auth_type'"),
"Session params query should include dab.auth_type for OBO");
Assert.IsTrue(
sessionParamsQuery.Contains("'dab.user_id'"),
"Session params query should include dab.user_id for OBO");
Assert.IsTrue(
sessionParamsQuery.Contains("'dab.tenant_id'"),
"Session params query should include dab.tenant_id for OBO");

// Verify the OBO parameter values are correct
Assert.IsTrue(
parameters.Values.Any(p => p.Value?.ToString() == "obo"),
"Parameters should contain auth_type value: obo");
Assert.IsTrue(
parameters.Values.Any(p => p.Value?.ToString() == "00000000-0000-0000-0000-000000000001"),
"Parameters should contain user_id value (oid)");
Assert.IsTrue(
parameters.Values.Any(p => p.Value?.ToString() == "11111111-1111-1111-1111-111111111111"),
"Parameters should contain tenant_id value");
}

/// <summary>
/// Validates that GetSessionParamsQuery does NOT include correlation values
/// when no Activity is present.
/// </summary>
[TestMethod, TestCategory(TestCategory.MSSQL)]
public void GetSessionParamsQuery_ExcludesCorrelationIds_WhenNoActivity()
{
// Arrange
TestHelper.SetupDatabaseEnvironment(TestCategory.MSSQL);
RuntimeConfig runtimeConfig = new(
Schema: "UnitTestSchema",
DataSource: new DataSource(
DatabaseType: DatabaseType.MSSQL,
ConnectionString: "Server=localhost;Database=TestDb;",
Options: new Dictionary<string, object> { { "set-session-context", true } }),
Runtime: new(
Rest: new(),
GraphQL: new(),
Mcp: new(),
Host: new(Cors: null, Authentication: null)),
Entities: new(new Dictionary<string, Entity>()));

MockFileSystem fileSystem = new();
fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(runtimeConfig.ToJson()));
FileSystemRuntimeConfigLoader loader = new(fileSystem);
RuntimeConfigProvider runtimeConfigProvider = new(loader);

Mock<ILogger<QueryExecutor<SqlConnection>>> queryExecutorLogger = new();
Mock<IHttpContextAccessor> httpContextAccessor = new();
DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(runtimeConfigProvider);

MsSqlQueryExecutor msSqlQueryExecutor = new(
runtimeConfigProvider,
dbExceptionParser,
queryExecutorLogger.Object,
httpContextAccessor.Object);

// Create a mock HttpContext with a simple authenticated user
Mock<HttpContext> mockContext = new();
Mock<HttpRequest> mockRequest = new();
Mock<IHeaderDictionary> mockHeaders = new();

mockHeaders.Setup(h => h["Authorization"]).Returns(string.Empty);
mockRequest.Setup(r => r.Headers).Returns(mockHeaders.Object);
mockContext.Setup(c => c.Request).Returns(mockRequest.Object);

var identity = new System.Security.Claims.ClaimsIdentity(
new[] { new System.Security.Claims.Claim("sub", "test-user") },
"TestAuth");
var principal = new System.Security.Claims.ClaimsPrincipal(identity);
mockContext.Setup(c => c.User).Returns(principal);

Dictionary<string, DbConnectionParam> parameters = new();

// Act - Ensure no Activity is present (Activity.Current should be null)
// We don't start any activity here
string sessionParamsQuery = msSqlQueryExecutor.GetSessionParamsQuery(
mockContext.Object,
parameters,
runtimeConfigProvider.GetConfig().DefaultDataSourceName);

// Assert
Assert.IsFalse(string.IsNullOrEmpty(sessionParamsQuery), "Session params query should not be empty (has user claims)");

// Verify trace_id and span_id are NOT included when no Activity
Assert.IsFalse(
sessionParamsQuery.Contains("'dab.trace_id'"),
"Session params query should NOT include dab.trace_id when no Activity present");
Assert.IsFalse(
sessionParamsQuery.Contains("'dab.span_id'"),
"Session params query should NOT include dab.span_id when no Activity present");
}

[TestCleanup]
public void CleanupAfterEachTest()
{
Expand Down