From 2d3defce9e3fc7a242bbf9f774644a8e326150c4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Feb 2026 22:54:10 +0000 Subject: [PATCH 1/6] Initial plan From aafb94660dc0a2dc514f3783aa7b3d78d74c8358 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:16:44 +0000 Subject: [PATCH 2/6] Fix duplicate OpenAPI tags by reusing global tag instances Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../Services/OpenAPI/OpenApiDocumentor.cs | 47 +++++++++++-------- .../StoredProcedureGeneration.cs | 37 +++++++++++++++ 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 87fb96bc32..cec2f12571 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -139,16 +139,22 @@ public void CreateDocument(bool doOverrideExistingDocument = false) }; // Collect all entity tags and their descriptions for the top-level tags array - List globalTags = new(); + // Store tags in a dictionary to ensure we can reuse the same tag instances in BuildPaths + Dictionary globalTagsDict = new(); foreach (KeyValuePair kvp in runtimeConfig.Entities) { Entity entity = kvp.Value; string restPath = entity.Rest?.Path ?? kvp.Key; - globalTags.Add(new OpenApiTag + + // Only add the tag if it hasn't been added yet (handles entities with the same REST path) + if (!globalTagsDict.ContainsKey(restPath)) { - Name = restPath, - Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description - }); + globalTagsDict[restPath] = new OpenApiTag + { + Name = restPath, + Description = string.IsNullOrWhiteSpace(entity.Description) ? 
null : entity.Description + }; + } } OpenApiDocument doc = new() @@ -162,9 +168,9 @@ public void CreateDocument(bool doOverrideExistingDocument = false) { new() { Url = url } }, - Paths = BuildPaths(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName), + Paths = BuildPaths(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName, globalTagsDict), Components = components, - Tags = globalTags + Tags = globalTagsDict.Values.ToList() }; _openApiDocument = doc; } @@ -193,7 +199,7 @@ public void CreateDocument(bool doOverrideExistingDocument = false) /// "/EntityName" /// /// All possible paths in the DAB engine's REST API endpoint. - private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSourceName) + private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSourceName, Dictionary globalTags) { OpenApiPaths pathsCollection = new(); @@ -227,19 +233,22 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour continue; } - // Set the tag's Description property to the entity's semantic description if present. - OpenApiTag openApiTag = new() + // Reuse the existing tag from the global tags dictionary instead of creating a new one + // This ensures Swagger UI displays only one group per entity + List tags = new(); + if (globalTags.TryGetValue(entityRestPath, out OpenApiTag? existingTag)) { - Name = entityRestPath, - Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description - }; - - // The OpenApiTag will categorize all paths created using the entity's name or overridden REST path value. - // The tag categorization will instruct OpenAPI document visualization tooling to display all generated paths together. 
- List tags = new() + tags.Add(existingTag); + } + else { - openApiTag - }; + // Fallback: create a new tag if not found in global tags (should not happen in normal flow) + tags.Add(new OpenApiTag + { + Name = entityRestPath, + Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description + }); + } Dictionary configuredRestOperations = GetConfiguredRestOperations(entity, dbObject); diff --git a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs index ffd5aaadde..c1e42556a4 100644 --- a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs +++ b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs @@ -149,6 +149,43 @@ public void OpenApiDocumentor_TagsIncludeEntityDescription() $"Expected tag for '{entityName}' with description '{expectedDescription}' not found."); } + /// + /// Integration test validating that there are no duplicate tags in the OpenAPI document. + /// This test ensures that tags created in CreateDocument are reused in BuildPaths, + /// preventing Swagger UI from showing duplicate entity groups. + /// + [TestMethod] + public void OpenApiDocumentor_NosDuplicateTags() + { + // Act: Get the tags from the OpenAPI document + IList tags = _openApiDocument.Tags; + + // Get all tag names + var tagNames = tags.Select(t => t.Name).ToList(); + + // Get distinct tag names + var distinctTagNames = tagNames.Distinct().ToList(); + + // Assert: The number of tags should equal the number of distinct tag names (no duplicates) + Assert.AreEqual(distinctTagNames.Count, tagNames.Count, + $"Duplicate tags found in OpenAPI document. 
Tags: {string.Join(", ", tagNames)}"); + + // Additionally, verify that each operation references tags that are in the global tags list + foreach (var path in _openApiDocument.Paths) + { + foreach (var operation in path.Value.Operations) + { + foreach (var operationTag in operation.Value.Tags) + { + // Verify that the operation's tag is the same instance as one in the global tags + bool foundMatchingTag = tags.Any(globalTag => ReferenceEquals(globalTag, operationTag)); + Assert.IsTrue(foundMatchingTag, + $"Operation tag '{operationTag.Name}' at path '{path.Key}' is not the same instance as the global tag"); + } + } + } + } + /// /// Validates that the provided OpenApiReference object has the expected schema reference id /// and that that id is present in the list of component schema in the OpenApi document. From db2d784d576a88dc8635dbeb64ae8e63bc5378b5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:17:55 +0000 Subject: [PATCH 3/6] Add missing System.Linq using statement Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Core/Services/OpenAPI/OpenApiDocumentor.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index cec2f12571..46c7e456f5 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; +using System.Linq; using System.Net; using System.Net.Mime; using System.Text; From 9d7ece77b47516d1fbcb3c50ea54021eb26d4ce2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:18:53 +0000 Subject: [PATCH 4/6] Fix typo in test method name Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- 
.../OpenApiDocumentor/StoredProcedureGeneration.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs index c1e42556a4..18f696430d 100644 --- a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs +++ b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs @@ -155,7 +155,7 @@ public void OpenApiDocumentor_TagsIncludeEntityDescription() /// preventing Swagger UI from showing duplicate entity groups. /// [TestMethod] - public void OpenApiDocumentor_NosDuplicateTags() + public void OpenApiDocumentor_NoDuplicateTags() { // Act: Get the tags from the OpenAPI document IList tags = _openApiDocument.Tags; From 46f6a5b0289680b4e2d08912ada5339ad8a5a572 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:13:24 +0000 Subject: [PATCH 5/6] Address code review feedback: fix tag key mismatch and improve code quality Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../Services/OpenAPI/OpenApiDocumentor.cs | 19 ++++++++++--------- .../StoredProcedureGeneration.cs | 10 +++++----- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 46c7e456f5..69e009b55f 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -145,17 +145,16 @@ public void CreateDocument(bool doOverrideExistingDocument = false) foreach (KeyValuePair kvp in runtimeConfig.Entities) { Entity entity = kvp.Value; - string restPath = entity.Rest?.Path ?? 
kvp.Key; + // Use GetEntityRestPath to ensure consistent path computation (with leading slash trimmed) + string restPath = GetEntityRestPath(entity.Rest, kvp.Key); // Only add the tag if it hasn't been added yet (handles entities with the same REST path) - if (!globalTagsDict.ContainsKey(restPath)) + // First entity's description wins when multiple entities share the same REST path. + globalTagsDict.TryAdd(restPath, new OpenApiTag { - globalTagsDict[restPath] = new OpenApiTag - { - Name = restPath, - Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description - }; - } + Name = restPath, + Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description + }); } OpenApiDocument doc = new() @@ -243,7 +242,9 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour } else { - // Fallback: create a new tag if not found in global tags (should not happen in normal flow) + // Fallback: create a new tag if not found in global tags. + // This should not happen in normal flow if GetEntityRestPath is used consistently. + // If this path is reached, it indicates a key mismatch between CreateDocument and BuildPaths. 
tags.Add(new OpenApiTag { Name = entityRestPath, diff --git a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs index 18f696430d..cb3ede0461 100644 --- a/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs +++ b/src/Service.Tests/OpenApiDocumentor/StoredProcedureGeneration.cs @@ -161,21 +161,21 @@ public void OpenApiDocumentor_NoDuplicateTags() IList tags = _openApiDocument.Tags; // Get all tag names - var tagNames = tags.Select(t => t.Name).ToList(); + List tagNames = tags.Select(t => t.Name).ToList(); // Get distinct tag names - var distinctTagNames = tagNames.Distinct().ToList(); + List distinctTagNames = tagNames.Distinct().ToList(); // Assert: The number of tags should equal the number of distinct tag names (no duplicates) Assert.AreEqual(distinctTagNames.Count, tagNames.Count, $"Duplicate tags found in OpenAPI document. Tags: {string.Join(", ", tagNames)}"); // Additionally, verify that each operation references tags that are in the global tags list - foreach (var path in _openApiDocument.Paths) + foreach (KeyValuePair path in _openApiDocument.Paths) { - foreach (var operation in path.Value.Operations) + foreach (KeyValuePair operation in path.Value.Operations) { - foreach (var operationTag in operation.Value.Tags) + foreach (OpenApiTag operationTag in operation.Value.Tags) { // Verify that the operation's tag is the same instance as one in the global tags bool foundMatchingTag = tags.Any(globalTag => ReferenceEquals(globalTag, operationTag)); From 004478248ed57ae181ba9c98b09b6125b52d455a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:40:29 +0000 Subject: [PATCH 6/6] Merge main and fix duplicate OpenAPI tags: reuse shared tag instances, use GetEntityRestPath consistently, log warning on key mismatch Co-authored-by: souvikghosh04 <210500244+souvikghosh04@users.noreply.github.com> --- 
.github/copilot-instructions.md | 251 +++++++ .gitignore | 1 + .pipelines/cosmos-pipelines.yml | 1 + .pipelines/dwsql-pipelines.yml | 17 +- .pipelines/mssql-pipelines.yml | 17 +- .pipelines/mysql-pipelines.yml | 7 +- .pipelines/pg-pipelines.yml | 7 +- CODEOWNERS | 2 +- Dockerfile | 2 + Nuget.config | 2 +- dab-config.json | 58 ++ global.json | 2 +- schemas/dab.draft.schema.json | 48 +- .../Azure.DataApiBuilder.Mcp.csproj | 8 +- .../BuiltInTools/CreateRecordTool.cs | 20 +- .../BuiltInTools/DescribeEntitiesTool.cs | 41 ++ .../BuiltInTools/ReadRecordsTool.cs | 6 + .../BuiltInTools/UpdateRecordTool.cs | 6 + .../Core/DynamicCustomTool.cs | 34 +- .../Core/McpServerConfiguration.cs | 121 ++-- .../Core/McpStdioServer.cs | 9 +- .../Core/McpToolRegistry.cs | 42 +- .../Utils/McpResponseBuilder.cs | 4 +- .../Utils/McpTelemetryErrorCodes.cs | 41 ++ .../Utils/McpTelemetryHelper.cs | 260 +++++++ src/Cli.Tests/AutoConfigSimulateTests.cs | 239 +++++++ src/Cli.Tests/AutoConfigTests.cs | 291 ++++++++ src/Cli.Tests/ConfigGeneratorTests.cs | 1 + src/Cli.Tests/ConfigureOptionsTests.cs | 250 +++++++ src/Cli.Tests/ModuleInitializer.cs | 4 + src/Cli.Tests/TestHelper.cs | 41 ++ .../UserDelegatedAuthRuntimeParsingTests.cs | 101 +++ src/Cli/Commands/AutoConfigOptions.cs | 104 +++ src/Cli/Commands/AutoConfigSimulateOptions.cs | 49 ++ src/Cli/Commands/ConfigureOptions.cs | 21 + src/Cli/ConfigGenerator.cs | 593 +++++++++++++++- src/Cli/Program.cs | 4 +- src/Config/Converters/AutoentityConverter.cs | 3 +- .../Converters/AutoentityTemplateConverter.cs | 12 +- .../CompressionOptionsConverterFactory.cs | 104 +++ .../Converters/DataSourceConverterFactory.cs | 20 +- ...DatasourceHealthOptionsConvertorFactory.cs | 16 +- .../RuntimeAutoentitiesConverter.cs | 2 +- src/Config/DataApiBuilderException.cs | 11 +- .../DatabasePrimitives/DatabaseObject.cs | 8 +- src/Config/ObjectModel/CompressionLevel.cs | 28 + src/Config/ObjectModel/CompressionOptions.cs | 46 ++ src/Config/ObjectModel/DataSource.cs | 78 ++ 
src/Config/ObjectModel/RuntimeAutoentities.cs | 21 +- src/Config/ObjectModel/RuntimeConfig.cs | 121 ++-- src/Config/ObjectModel/RuntimeOptions.cs | 5 +- src/Config/RuntimeConfigLoader.cs | 6 + .../Configurations/RuntimeConfigProvider.cs | 15 + .../Configurations/RuntimeConfigValidator.cs | 122 +++- .../RuntimeConfigValidatorUtil.cs | 88 +++ src/Core/Models/GraphQLFilterParsers.cs | 1 + .../RestRequestContexts/RestRequestContext.cs | 6 + src/Core/Parsers/RequestParser.cs | 63 +- .../Factories/QueryManagerFactory.cs | 10 +- src/Core/Resolvers/IMsalClientWrapper.cs | 25 + src/Core/Resolvers/IOboTokenProvider.cs | 33 + src/Core/Resolvers/IQueryBuilder.cs | 2 + src/Core/Resolvers/MsSqlQueryBuilder.cs | 185 ++++- src/Core/Resolvers/MsSqlQueryExecutor.cs | 246 ++++++- src/Core/Resolvers/MsalClientWrapper.cs | 37 + src/Core/Resolvers/OboSqlTokenProvider.cs | 246 +++++++ src/Core/Resolvers/SqlMutationEngine.cs | 78 ++ .../CosmosSqlMetadataProvider.cs | 3 +- .../MetadataProviderFactory.cs | 13 +- .../MsSqlMetadataProvider.cs | 120 +++- .../MySqlMetadataProvider.cs | 3 +- .../PostgreSqlMetadataProvider.cs | 3 +- .../MetadataProviders/SqlMetadataProvider.cs | 64 +- .../Services/OpenAPI/IOpenApiDocumentor.cs | 9 + .../Services/OpenAPI/OpenApiDocumentor.cs | 671 +++++++++++++----- src/Core/Services/RequestValidator.cs | 49 +- src/Core/Services/RestService.cs | 88 ++- src/Core/Telemetry/TelemetryTracesHelper.cs | 78 +- src/Directory.Build.props | 2 +- src/Directory.Packages.props | 19 +- .../Helpers/RuntimeConfigAuthHelper.cs | 27 + .../JwtTokenAuthenticationUnitTests.cs | 10 +- .../OboSqlTokenProviderUnitTests.cs | 378 ++++++++++ .../Azure.DataApiBuilder.Service.Tests.csproj | 1 + .../Caching/CachingConfigProcessingTests.cs | 44 ++ .../DabCacheServiceIntegrationTests.cs | 5 + .../CompressionIntegrationTests.cs | 300 ++++++++ .../Configuration/ConfigurationTests.cs | 425 ++++++++++- .../Configuration/RuntimeConfigLoaderTests.cs | 30 + .../Telemetry/AzureLogAnalyticsTests.cs | 11 
+- .../CosmosTests/QueryFilterTests.cs | 25 + src/Service.Tests/CosmosTests/TestBase.cs | 6 +- .../GraphQLBuilder/MultiSourceBuilderTests.cs | 5 +- .../MultipleMutationBuilderTests.cs | 12 + .../Mcp/DescribeEntitiesFilteringTests.cs | 504 +++++++++++++ .../EntityLevelDmlToolConfigurationTests.cs | 175 ++++- src/Service.Tests/Mcp/McpToolRegistryTests.cs | 337 +++++++++ src/Service.Tests/ModuleInitializer.cs | 4 + src/Service.Tests/Multidab-config.MsSql.json | 78 +- .../DocumentVerbosityTests.cs | 12 +- .../OpenApiDocumentor/FieldFilteringTests.cs | 80 +++ .../OpenApiDocumentor/OpenApiTestBootstrap.cs | 17 +- .../OperationFilteringTests.cs | 134 ++++ .../ParameterValidationTests.cs | 9 +- .../OpenApiDocumentor/PathValidationTests.cs | 1 - .../RequestBodyStrictTests.cs | 79 +++ .../OpenApiDocumentor/RoleIsolationTests.cs | 186 +++++ .../RoleSpecificEndpointTests.cs | 213 ++++++ .../RestApiTests/Find/DwSqlFindApiTests.cs | 6 + .../RestApiTests/Find/FindApiTestBase.cs | 17 + .../RestApiTests/Find/MsSqlFindApiTests.cs | 6 + .../RestApiTests/Find/MySqlFindApiTests.cs | 12 + .../Find/PostgreSqlFindApiTests.cs | 11 + .../RestApiTests/Patch/MsSqlPatchApiTests.cs | 7 + .../RestApiTests/Patch/MySqlPatchApiTests.cs | 11 + .../RestApiTests/Patch/PatchApiTestBase.cs | 74 +- .../Patch/PostgreSqlPatchApiTests.cs | 12 + .../RestApiTests/Put/MsSqlPutApiTests.cs | 7 + .../RestApiTests/Put/MySqlPutApiTests.cs | 12 + .../RestApiTests/Put/PostgreSqlPutApiTests.cs | 12 + .../RestApiTests/Put/PutApiTestBase.cs | 97 ++- src/Service.Tests/SqlTests/SqlTestBase.cs | 9 + .../UnitTests/ConfigValidationUnitTests.cs | 551 +++++++++++++- .../HealthCheckUtilitiesUnitTests.cs | 305 ++++++++ .../UnitTests/McpTelemetryTests.cs | 382 ++++++++++ .../UnitTests/RequestParserUnitTests.cs | 80 +++ .../UnitTests/RestServiceUnitTests.cs | 125 +++- .../UnitTests/SqlMetadataProviderUnitTests.cs | 82 ++- .../UnitTests/SqlQueryExecutorUnitTests.cs | 352 ++++++++- src/Service.Tests/UnitTests/StartupTests.cs | 32 + 
src/Service.Tests/dab-config.DwSql.json | 2 +- src/Service.Tests/dab-config.MsSql.json | 2 +- .../Azure.DataApiBuilder.Service.csproj | 1 + src/Service/Controllers/RestController.cs | 28 + ...ComprehensiveHealthReportResponseWriter.cs | 35 +- src/Service/HealthCheck/HealthCheckHelper.cs | 70 +- src/Service/HealthCheck/HttpUtilities.cs | 4 +- .../Model/ComprehensiveHealthCheckReport.cs | 6 + src/Service/HealthCheck/Utilities.cs | 29 + src/Service/Startup.cs | 172 ++++- 140 files changed, 10528 insertions(+), 596 deletions(-) create mode 100644 .github/copilot-instructions.md create mode 100644 dab-config.json create mode 100644 src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs create mode 100644 src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs create mode 100644 src/Cli.Tests/AutoConfigSimulateTests.cs create mode 100644 src/Cli.Tests/AutoConfigTests.cs create mode 100644 src/Cli.Tests/UserDelegatedAuthRuntimeParsingTests.cs create mode 100644 src/Cli/Commands/AutoConfigOptions.cs create mode 100644 src/Cli/Commands/AutoConfigSimulateOptions.cs create mode 100644 src/Config/Converters/CompressionOptionsConverterFactory.cs create mode 100644 src/Config/ObjectModel/CompressionLevel.cs create mode 100644 src/Config/ObjectModel/CompressionOptions.cs create mode 100644 src/Core/Resolvers/IMsalClientWrapper.cs create mode 100644 src/Core/Resolvers/IOboTokenProvider.cs create mode 100644 src/Core/Resolvers/MsalClientWrapper.cs create mode 100644 src/Core/Resolvers/OboSqlTokenProvider.cs create mode 100644 src/Service.Tests/Authentication/OboSqlTokenProviderUnitTests.cs create mode 100644 src/Service.Tests/Configuration/CompressionIntegrationTests.cs create mode 100644 src/Service.Tests/Mcp/DescribeEntitiesFilteringTests.cs create mode 100644 src/Service.Tests/Mcp/McpToolRegistryTests.cs create mode 100644 src/Service.Tests/OpenApiDocumentor/FieldFilteringTests.cs create mode 100644 src/Service.Tests/OpenApiDocumentor/OperationFilteringTests.cs create 
mode 100644 src/Service.Tests/OpenApiDocumentor/RequestBodyStrictTests.cs create mode 100644 src/Service.Tests/OpenApiDocumentor/RoleIsolationTests.cs create mode 100644 src/Service.Tests/OpenApiDocumentor/RoleSpecificEndpointTests.cs create mode 100644 src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs create mode 100644 src/Service.Tests/UnitTests/McpTelemetryTests.cs create mode 100644 src/Service.Tests/UnitTests/RequestParserUnitTests.cs create mode 100644 src/Service.Tests/UnitTests/StartupTests.cs diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000..0ac6e54eef --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,251 @@ +# Data API builder (DAB) - Copilot Instructions + +## Project Overview + +Data API builder (DAB) is an open-source, no-code tool that creates secure, full-featured REST, GraphQL endpoints for databases. It's a CRUD data API engine that runs in a container—on Azure, any other cloud, or on-premises. It also supports creation of DML and Custom MCP tools to build a SQL MCP Server backed by a SQL database. 
+ +### Key Technologies +- **Language**: C# / .NET +- **.NET Version**: .NET 8.0 (see `global.json`) +- **Supported Databases**: Azure SQL, SQL Server, SQLDW, Cosmos DB, PostgreSQL, MySQL +- **API Types**: REST, GraphQL, MCP +- **Deployment**: Cross-platform (Azure, AWS, GCP, on-premises) + +## Project Structure + +``` +data-api-builder/ +├── src/ +│ ├── Auth/ # Authentication logic +│ ├── Cli/ # Command-line interface (dab CLI) +│ ├── Cli.Tests/ # CLI tests +│ ├── Config/ # Configuration handling +│ ├── Core/ # Core engine components +│ ├── Service/ # Main DAB service/runtime +│ ├── Service.GraphQLBuilder/ # GraphQL schema builder +│ ├── Service.Tests/ # Integration tests +│ └── Azure.DataApiBuilder.sln # Main solution file +├── config-generators/ # Config file generation helpers +├── docs/ # Documentation +├── samples/ # Sample configurations and projects +├── schemas/ # JSON schemas for config validation +├── scripts/ # Build and utility scripts +└── templates/ # Project templates +``` + +## Building and Testing + +### Prerequisites +- .NET 8.0 SDK or later +- Database server for testing (SQL Server, PostgreSQL, MySQL, or Cosmos DB) + +### Building the Project + +```bash +# Build the entire solution +dotnet build src/Azure.DataApiBuilder.sln + +# Clean and rebuild +dotnet clean src/Azure.DataApiBuilder.sln +dotnet build src/Azure.DataApiBuilder.sln +``` + +### Running Tests + +DAB uses integration tests that require database instances with proper schemas. 
+ +**SQL-based tests:** +```bash +# MsSql tests +dotnet test --filter "TestCategory=MsSql" + +# PostgreSQL tests +dotnet test --filter "TestCategory=PostgreSql" + +# MySQL tests +dotnet test --filter "TestCategory=MySql" +``` + +**CosmosDB tests:** +```bash +dotnet test --filter "TestCategory=CosmosDb_NoSql" +``` + +**Test Configuration:** +- Test database schemas are in `src/Service.Tests/DatabaseSchema-.sql` +- Config files are `src/Service.Tests/dab-config..json` +- Connection strings should use `@env('variable_name')` syntax - never commit connection strings + +### Running Locally + +1. Open the solution: `src/Azure.DataApiBuilder.sln` +2. Copy a config file from `src/Service.Tests/dab-config..json` to `src/Service/` +3. Update connection string (use environment variables) +4. Set `Azure.DataApiBuilder.Service` as startup project +5. Select debug profile: `MsSql`, `PostgreSql`, `CosmosDb_NoSql`, or `MySql` +6. Build and run + +## Code Style and Conventions + +### Formatting +- **Tool**: `dotnet format` (enforced in CI) +- **Indentation**: 4 spaces for C# code, 2 spaces for YAML/JSON +- **Line endings**: LF (Unix-style) +- **Character encoding**: UTF-8 +- **Trailing whitespace**: Removed +- **Final newline**: Required +- Refer to `.\src\.editorconfig` for additional formatting conventions. 
+ +### Running Code Formatter + +```bash +# Format all files +dotnet format src/Azure.DataApiBuilder.sln + +# Verify formatting (CI check) +dotnet format src/Azure.DataApiBuilder.sln --verify-no-changes +``` + +### C# Conventions +- **Usings**: Sort system directives first, no separation between groups +- **Type preferences**: Use language keywords (`int`, `string`) over BCL types (`Int32`, `String`) +- **Naming**: Follow standard .NET naming conventions +- **`this.` qualifier**: Not used unless necessary + +### SQL Query Formatting +When adding or modifying generated SQL queries in tests: +- **PostgreSQL**: Use https://sqlformat.org/ (remove unnecessary double quotes) +- **SQL Server**: Use https://poorsql.com/ (enable "trailing commas", indent string: `\s\s\s\s`) +- **MySQL**: Use https://poorsql.com/ (same as SQL Server, max line width: 100) + +## Testing Guidelines + +### Test Organization +- Integration tests validate the engine's query generation and database operations +- Tests are organized by database type using TestCategory attributes +- Each database type has its own config file and schema + +### Adding New Tests +- Work within the existing database schema (SQL) or GraphQL schema (CosmosDB) +- Add tests to the appropriate test class +- Use base class methods and helpers for engine operations +- Format any generated SQL queries using the specified formatters +- Do not commit connection strings to the repository + +### Test Database Setup +1. Create database using the appropriate server +2. Run the schema script: `src/Service.Tests/DatabaseSchema-.sql` +3. Set connection string in config using `@env()` syntax +4. 
Run tests for that specific database type + +## Configuration Files + +### DAB Configuration +- Config files use JSON format with schema validation +- Schema files are in the `schemas/` directory +- Use `@env('variable_name')` to reference environment variables +- Never commit connection strings or secrets + +### Config Generation +Use the config-generators directory for automated config file creation: +```bash +# Build with config generation +dotnet build -p:generateConfigFileForDbType= +``` +Supported types: `mssql`, `postgresql`, `cosmosdb_nosql`, `mysql` + +## Security Practices + +- **Never commit secrets**: Use environment variables with `@env()` syntax +- **Connection strings**: Always use `.env` files (add to `.gitignore`) +- **Authentication**: Supports AppService, EasyAuth, StaticWebApps, JWT +- **Authorization**: Role-based permissions in config +- **set-session-context**: Available for SQL Server row-level security + +## API Development + +### REST API +- Base path: `/api` (configurable) +- Follows Microsoft REST API Guidelines +- Request body validation available +- Health endpoint: `/health` +- Swagger UI in development mode: `/{REST_PATH}/openapi` (default: `/api/openapi`) + +### GraphQL API +- Base path: `/graphql` (configurable) +- Introspection enabled in development mode +- Nitro UI in development mode: `/graphql` +- Schema generated from database metadata +### MCP Tools +- Base Path: `/mcp` (configurable) +- Discover tools with MCP Inspector + +## Common Commands + +```bash +# Install DAB CLI globally +dotnet tool install microsoft.dataapibuilder -g + +# Initialize a new config +dab init --database-type --connection-string "@env('connection_string')" --host-mode development + +# Add an entity to config +dab add --source --permissions "anonymous:*" + +# Start DAB locally +dab start + +# Validate a config file +dab validate +``` + +## Contributing + +- Sign the Contributor License Agreement (CLA) +- Follow the Microsoft Open Source Code of Conduct 
+- Use issue templates when reporting bugs or requesting features +- Include configuration files, logs, and hosting model in issue reports +- Run `dotnet format` before committing +- Do not commit connection strings or other secrets + +### Commit Signing + +All commits should be signed to receive the verified badge on GitHub. Configure GPG or SSH signing: + +**GPG Signing:** +```bash +# Generate a GPG key +gpg --full-generate-key + +# List keys and copy the key ID +gpg --list-secret-keys --keyid-format=long + +# Configure Git to use the key +git config --global user.signingkey +git config --global commit.gpgsign true + +# Add GPG key to GitHub account +gpg --armor --export +``` + +**SSH Signing:** +```bash +# Generate an SSH key +ssh-keygen -t ed25519 -C "your_email@example.com" + +# Configure Git to use SSH signing +git config --global gpg.format ssh +git config --global user.signingkey ~/.ssh/id_ed25519.pub +git config --global commit.gpgsign true + +# Add SSH key to GitHub account as signing key +``` + +## References + +- [Official Documentation](https://learn.microsoft.com/azure/data-api-builder/) +- [Samples](https://aka.ms/dab/samples) +- [Known Issues](https://learn.microsoft.com/azure/data-api-builder/known-issues) +- [Feature Roadmap](https://github.com/Azure/data-api-builder/discussions/1377) +- [Microsoft REST API Guidelines](https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md) +- [GraphQL Specification](https://graphql.org/) diff --git a/.gitignore b/.gitignore index 56bd0e435d..4de9ba1a81 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ dab-config*.json *.dab-config.json !dab-config.*reference.json !dab-config.*.example.json +!/dab-config.json *.cd diff --git a/.pipelines/cosmos-pipelines.yml b/.pipelines/cosmos-pipelines.yml index 61d066aac7..c664879288 100644 --- a/.pipelines/cosmos-pipelines.yml +++ b/.pipelines/cosmos-pipelines.yml @@ -69,6 +69,7 @@ steps: projects: '$(solution)' feedsToUse: config nugetConfigPath: 
Nuget.config + restoreArguments: '/p:RuntimeIdentifiers=""' - task: DotNetCoreCLI@2 displayName: Build diff --git a/.pipelines/dwsql-pipelines.yml b/.pipelines/dwsql-pipelines.yml index 767c6c7df9..bc1e36dd6d 100644 --- a/.pipelines/dwsql-pipelines.yml +++ b/.pipelines/dwsql-pipelines.yml @@ -61,12 +61,7 @@ jobs: projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config - - - task: DockerInstaller@0 - displayName: Docker Installer - inputs: - dockerVersion: 17.09.0-ce - releaseType: stable + restoreArguments: '/p:RuntimeIdentifiers=""' - task: Bash@3 displayName: 'Generate password' @@ -81,7 +76,7 @@ jobs: inputs: targetType: 'inline' script: | - connectionString="Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=SA;Password=$(dbPassword);MultipleActiveResultSets=False;Connection Timeout=5;TrustServerCertificate=True;Encrypt=False;" + connectionString="Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=SA;Password=$(dbPassword);MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True;Encrypt=False;" echo "##vso[task.setvariable variable=data-source.connection-string;]$connectionString" - task: Bash@3 @@ -163,7 +158,7 @@ jobs: # since windows needs a different string. # The variable setting on the pipeline UI sets the connection string # for the linux job above. 
- data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=5;TrustServerCertificate=True; + data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True; InstallerUrl: https://download.microsoft.com/download/7/c/1/7c14e92e-bdcb-4f89-b7cf-93543e7112d1/SqlLocalDB.msi SqlVersionCode: '15.0' @@ -188,12 +183,14 @@ jobs: - task: NuGetToolInstaller@1 - - task: NuGetCommand@2 + - task: DotNetCoreCLI@2 displayName: Restore NuGet packages inputs: - restoreSolution: '$(solution)' + command: restore + projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config + restoreArguments: '/p:RuntimeIdentifiers=""' - task: PowerShell@2 displayName: Install SQL LocalDB # Update when clarity on how to setup diff --git a/.pipelines/mssql-pipelines.yml b/.pipelines/mssql-pipelines.yml index af823db200..d2d2d8c7fa 100644 --- a/.pipelines/mssql-pipelines.yml +++ b/.pipelines/mssql-pipelines.yml @@ -62,12 +62,7 @@ jobs: projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config - - - task: DockerInstaller@0 - displayName: Docker Installer - inputs: - dockerVersion: 17.09.0-ce - releaseType: stable + restoreArguments: '/p:RuntimeIdentifiers=""' - task: Bash@3 displayName: 'Generate password' @@ -82,7 +77,7 @@ jobs: inputs: targetType: 'inline' script: | - $connectionString="Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=SA;Password=$(dbPassword);MultipleActiveResultSets=False;Connection Timeout=5;TrustServerCertificate=True;Encrypt=False;" + $connectionString="Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=SA;Password=$(dbPassword);MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True;Encrypt=False;" Write-Host "##vso[task.setvariable 
variable=data-source.connection-string]$connectionString" - task: Bash@3 @@ -167,7 +162,7 @@ jobs: # since windows needs a different string. # The variable setting on the pipeline UI sets the connection string # for the linux job above. - data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=5;TrustServerCertificate=True; + data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True; InstallerUrl: https://download.microsoft.com/download/7/c/1/7c14e92e-bdcb-4f89-b7cf-93543e7112d1/SqlLocalDB.msi SqlVersionCode: '15.0' @@ -192,12 +187,14 @@ jobs: - task: NuGetToolInstaller@1 - - task: NuGetCommand@2 + - task: DotNetCoreCLI@2 displayName: Restore NuGet packages inputs: - restoreSolution: '$(solution)' + command: restore + projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config + restoreArguments: '/p:RuntimeIdentifiers=""' - task: PowerShell@2 displayName: Install SQL LocalDB diff --git a/.pipelines/mysql-pipelines.yml b/.pipelines/mysql-pipelines.yml index 7377d5b69c..e6fe59d6c0 100644 --- a/.pipelines/mysql-pipelines.yml +++ b/.pipelines/mysql-pipelines.yml @@ -60,12 +60,7 @@ jobs: projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config - - - task: DockerInstaller@0 - displayName: Docker Installer - inputs: - dockerVersion: 17.09.0-ce - releaseType: stable + restoreArguments: '/p:RuntimeIdentifiers=""' - task: Bash@3 displayName: 'Generate password' diff --git a/.pipelines/pg-pipelines.yml b/.pipelines/pg-pipelines.yml index d6b7600e1f..6cd223aca6 100644 --- a/.pipelines/pg-pipelines.yml +++ b/.pipelines/pg-pipelines.yml @@ -55,12 +55,7 @@ jobs: projects: '$(solution)' feedsToUse: config nugetConfigPath: Nuget.config - - - task: DockerInstaller@0 - displayName: Docker Installer - inputs: - 
dockerVersion: 17.09.0-ce - releaseType: stable + restoreArguments: '/p:RuntimeIdentifiers=""' - task: Bash@3 displayName: 'Generate password' diff --git a/CODEOWNERS b/CODEOWNERS index ac65d7bc52..53b1028969 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence, # review when someone opens a pull request. -* @Aniruddh25 @aaronburtle @anushakolan @RubenCerna2079 @souvikghosh04 @neeraj-sharma2592 @sourabh1007 @vadeveka @Alekhya-Polavarapu @rusamant @stuartpa +* @Aniruddh25 @aaronburtle @anushakolan @jerrynixon @RubenCerna2079 @souvikghosh04 @sourabh1007 @vadeveka @Alekhya-Polavarapu @rusamant @stuartpa code_of_conduct.md @jerrynixon contributing.md @jerrynixon diff --git a/Dockerfile b/Dockerfile index d6d950733c..59190b6374 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,8 @@ RUN dotnet build "./src/Service/Azure.DataApiBuilder.Service.csproj" -c Docker - FROM mcr.microsoft.com/dotnet/aspnet:8.0-cbl-mariner2.0 AS runtime COPY --from=build /out /App +# Add default dab-config.json to /App in the image +COPY dab-config.json /App/dab-config.json WORKDIR /App ENV ASPNETCORE_URLS=http://+:5000 ENTRYPOINT ["dotnet", "Azure.DataApiBuilder.Service.dll"] diff --git a/Nuget.config b/Nuget.config index 704c9d13ba..8235dd6afb 100644 --- a/Nuget.config +++ b/Nuget.config @@ -1,4 +1,4 @@ - + diff --git a/dab-config.json b/dab-config.json new file mode 100644 index 0000000000..adc4f37fb8 --- /dev/null +++ b/dab-config.json @@ -0,0 +1,58 @@ +{ + "$schema": "https://github.com/Azure/data-api-builder/releases/latest/download/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@env('DAB_CONNSTRING')", + "options": { + "set-session-context": false + } + }, + "runtime": { + "rest": { + "enabled": true, + "path": "/api", + "request-body-strict": false + }, + "graphql": { + "enabled": true, + "path": "/graphql", + 
"allow-introspection": true + }, + "mcp": { + "enabled": true, + "path": "/mcp" + }, + "host": { + "cors": { + "origins": [], + "allow-credentials": false + }, + "authentication": { + "provider": "Simulator" + }, + "mode": "development" + } + }, + "entities": {}, + "autoentities": { + "default": { + "template": { + "mcp": { "dml-tools": true }, + "rest": { "enabled": true }, + "graphql": { "enabled": true }, + "health": { "enabled": true }, + "cache": { + "enabled": false + } + }, + "permissions": [ + { + "role": "anonymous", + "actions": [ + "create", "read", "update", "delete" + ] + } + ] + } + } +} diff --git a/global.json b/global.json index 1bdb496ef0..17811390a4 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "8.0.417", + "version": "8.0.418", "rollForward": "latestFeature" } } diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 920c0a4da6..cbe38b7d72 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -61,6 +61,28 @@ "maximum": 2147483647 } } + }, + "user-delegated-auth": { + "description": "User-delegated authentication configuration for On-Behalf-Of (OBO) flow. Enables DAB to connect to the database using the calling user's identity.", + "type": ["object", "null"], + "additionalProperties": false, + "properties": { + "enabled": { + "$ref": "#/$defs/boolean-or-string", + "description": "Enable user-delegated authentication (OBO flow).", + "default": false + }, + "provider": { + "type": "string", + "description": "Identity provider for user-delegated authentication.", + "enum": ["EntraId"], + "default": "EntraId" + }, + "database-audience": { + "type": "string", + "description": "The audience URI for the target database (e.g., https://database.windows.net for Azure SQL)." 
+ } + } } }, "allOf": [ @@ -424,6 +446,19 @@ } } }, + "compression": { + "type": "object", + "description": "Configures HTTP response compression settings.", + "additionalProperties": false, + "properties": { + "level": { + "type": "string", + "enum": ["optimal", "fastest", "none"], + "default": "optimal", + "description": "Specifies the response compression level. 'optimal' provides best compression ratio, 'fastest' prioritizes speed, 'none' disables compression." + } + } + }, "telemetry": { "type": "object", "description": "Telemetry configuration", @@ -800,7 +835,7 @@ "level": { "type": "string", "description": "Cache level (L1 or L1L2)", - "enum": [ "L1", "L1L2", null ], + "enum": [ "L1", "L1L2" ], "default": "L1L2" } } @@ -1116,9 +1151,16 @@ "default": false }, "ttl-seconds": { - "type": "integer", + "type": [ "integer", "null" ], "description": "Time to live in seconds", - "default": 5 + "default": 5, + "minimum": 1 + }, + "level": { + "type": "string", + "description": "Cache level (L1 or L1L2)", + "enum": [ "L1", "L1L2" ], + "default": "L1L2" } } }, diff --git a/src/Azure.DataApiBuilder.Mcp/Azure.DataApiBuilder.Mcp.csproj b/src/Azure.DataApiBuilder.Mcp/Azure.DataApiBuilder.Mcp.csproj index f675f8d8d1..c1e4f9cfe4 100644 --- a/src/Azure.DataApiBuilder.Mcp/Azure.DataApiBuilder.Mcp.csproj +++ b/src/Azure.DataApiBuilder.Mcp/Azure.DataApiBuilder.Mcp.csproj @@ -1,11 +1,17 @@ - + net8.0 enable enable + + $(NoWarn);NU1603 + + + + diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs index 5ac12f5988..cab4b69bdb 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs @@ -124,17 +124,23 @@ public async Task ExecuteAsync( } JsonElement insertPayloadRoot = dataElement.Clone(); + + // Validate it's a table or view - stored procedures use execute_entity + if (dbObject.SourceType != 
EntitySourceType.Table && dbObject.SourceType != EntitySourceType.View) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity '{entityName}' is not a table or view. For stored procedures, use the execute_entity tool instead.", logger); + } + InsertRequestContext insertRequestContext = new( entityName, dbObject, insertPayloadRoot, EntityActionOperation.Insert); - RequestValidator requestValidator = serviceProvider.GetRequiredService(); - - // Only validate tables + // Only validate tables. For views, skip validation and let the database handle any errors. if (dbObject.SourceType is EntitySourceType.Table) { + RequestValidator requestValidator = serviceProvider.GetRequiredService(); try { requestValidator.ValidateInsertRequestContext(insertRequestContext); @@ -144,14 +150,6 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger); } } - else - { - return McpResponseBuilder.BuildErrorResult( - toolName, - "InvalidCreateTarget", - "The create_record tool is only available for tables.", - logger); - } IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); DatabaseType databaseType = sqlMetadataProvider.GetDatabaseType(); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs index c5b283214f..95e37c4498 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs @@ -146,6 +146,10 @@ public Task ExecuteAsync( List> entityList = new(); + // Track how many entities were filtered out because DML tools are disabled (dml-tools: false). + // This helps provide a more specific error message when all entities are filtered. 
+ int filteredDmlDisabledCount = 0; + if (runtimeConfig.Entities != null) { foreach (KeyValuePair entityEntry in runtimeConfig.Entities) @@ -155,11 +159,22 @@ public Task ExecuteAsync( string entityName = entityEntry.Key; Entity entity = entityEntry.Value; + // Check entity filter first to avoid counting entities that wouldn't be included anyway if (!ShouldIncludeEntity(entityName, entityFilter)) { continue; } + // Filter out entities when dml-tools is explicitly disabled (false). + // This applies to all entity types (tables, views, stored procedures). + // When dml-tools is false, the entity is not exposed via DML tools + // (read_records, create_record, etc.) and should not appear in describe_entities. + if (entity.Mcp?.DmlToolEnabled == false) + { + filteredDmlDisabledCount++; + continue; + } + try { Dictionary entityInfo = nameOnly @@ -177,6 +192,7 @@ public Task ExecuteAsync( if (entityList.Count == 0) { + // No entities matched the filter criteria if (entityFilter != null && entityFilter.Count > 0) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( @@ -185,6 +201,20 @@ public Task ExecuteAsync( $"No entities found matching the filter: {string.Join(", ", entityFilter)}", logger)); } + // Return a specific error when ALL configured entities have dml-tools: false. + // Only show this error when every entity was intentionally filtered by the dml-tools check above, + // not when some entities failed to build due to exceptions in BuildBasicEntityInfo() or BuildFullEntityInfo() functions. + else if (filteredDmlDisabledCount > 0 && + runtimeConfig.Entities != null && + filteredDmlDisabledCount == runtimeConfig.Entities.Entities.Count) + { + return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, + "AllEntitiesFilteredDmlDisabled", + $"All {filteredDmlDisabledCount} configured entities have DML tools disabled (dml-tools: false). Entities with dml-tools disabled do not appear in describe_entities. 
If the filtered entities are stored procedures with custom-tool enabled, check tools/list.", + logger)); + } + // Truly no entities configured in the runtime config, or entities failed to build for other reasons else { return Task.FromResult(McpResponseBuilder.BuildErrorResult( @@ -207,6 +237,17 @@ public Task ExecuteAsync( ["count"] = finalEntityList.Count }; + // Log when entities were filtered due to DML tools disabled for visibility + if (filteredDmlDisabledCount > 0) + { + logger?.LogInformation( + "DescribeEntitiesTool: {FilteredCount} entity(ies) filtered with DML tools disabled (dml-tools: false). " + + "These entities are not exposed via DML tools and do not appear in describe_entities response. " + + "Returned {ReturnedCount} entities.", + filteredDmlDisabledCount, + finalEntityList.Count); + } + logger?.LogInformation( "DescribeEntitiesTool returned {EntityCount} entities. Response type: {ResponseType} (nameOnly={NameOnly}).", finalEntityList.Count, diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs index 64e73f0281..dbbc338c76 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -158,6 +158,12 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } + // Validate it's a table or view + if (dbObject.SourceType != EntitySourceType.Table && dbObject.SourceType != EntitySourceType.View) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity '{entityName}' is not a table or view. 
For stored procedures, use the execute_entity tool instead.", logger); + } + // Authorization check in the existing entity IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index ed2a9f3ce4..883ddde02e 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -137,6 +137,12 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } + // Validate it's a table or view + if (dbObject.SourceType != EntitySourceType.Table && dbObject.SourceType != EntitySourceType.View) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity '{entityName}' is not a table or view. For stored procedures, use the execute_entity tool instead.", logger); + } + // 5) Authorization after we have a known entity IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); HttpContext? httpContext = httpContextAccessor.HttpContext; diff --git a/src/Azure.DataApiBuilder.Mcp/Core/DynamicCustomTool.cs b/src/Azure.DataApiBuilder.Mcp/Core/DynamicCustomTool.cs index ea2fa0cfea..f724d0d1ba 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/DynamicCustomTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/DynamicCustomTool.cs @@ -38,7 +38,6 @@ namespace Azure.DataApiBuilder.Mcp.Core /// public class DynamicCustomTool : IMcpTool { - private readonly string _entityName; private readonly Entity _entity; /// @@ -48,7 +47,7 @@ public class DynamicCustomTool : IMcpTool /// The entity configuration object. public DynamicCustomTool(string entityName, Entity entity) { - _entityName = entityName ?? 
throw new ArgumentNullException(nameof(entityName)); + EntityName = entityName ?? throw new ArgumentNullException(nameof(entityName)); _entity = entity ?? throw new ArgumentNullException(nameof(entity)); // Validate that this is a stored procedure @@ -65,12 +64,17 @@ public DynamicCustomTool(string entityName, Entity entity) /// public ToolType ToolType { get; } = ToolType.Custom; + /// + /// Gets the entity name associated with this custom tool. + /// + public string EntityName { get; } + /// /// Gets the metadata for this custom tool, including name, description, and input schema. /// public Tool GetToolMetadata() { - string toolName = ConvertToToolName(_entityName); + string toolName = ConvertToToolName(EntityName); string description = _entity.Description ?? $"Executes the {toolName} stored procedure"; // Build input schema based on parameters @@ -114,25 +118,25 @@ public async Task ExecuteAsync( } // 3) Validate entity still exists in configuration - if (!config.Entities.TryGetValue(_entityName, out Entity? entityConfig)) + if (!config.Entities.TryGetValue(EntityName, out Entity? 
entityConfig)) { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{_entityName}' not found in configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{EntityName}' not found in configuration.", logger); } if (entityConfig.Source.Type != EntitySourceType.StoredProcedure) { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity {_entityName} is not a stored procedure.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity {EntityName} is not a stored procedure.", logger); } // Check if custom tool is still enabled for this entity if (entityConfig.Mcp?.CustomToolEnabled != true) { - return McpErrorHelpers.ToolDisabled(toolName, logger, $"Custom tool is disabled for entity '{_entityName}'."); + return McpErrorHelpers.ToolDisabled(toolName, logger, $"Custom tool is disabled for entity '{EntityName}'."); } // 4) Resolve metadata if (!McpMetadataHelper.TryResolveMetadata( - _entityName, + EntityName, config, serviceProvider, out ISqlMetadataProvider sqlMetadataProvider, @@ -150,18 +154,18 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpErrorHelpers.PermissionDenied(toolName, _entityName, "execute", roleError, logger); + return McpErrorHelpers.PermissionDenied(toolName, EntityName, "execute", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( httpContext!, authResolver, - _entityName, + EntityName, EntityActionOperation.Execute, out string? 
effectiveRole, out string authError)) { - return McpErrorHelpers.PermissionDenied(toolName, _entityName, "execute", authError, logger); + return McpErrorHelpers.PermissionDenied(toolName, EntityName, "execute", authError, logger); } // 6) Build request payload @@ -175,7 +179,7 @@ public async Task ExecuteAsync( // 7) Build stored procedure execution context StoredProcedureRequestContext context = new( - entityName: _entityName, + entityName: EntityName, dbo: dbObject, requestPayloadRoot: requestPayloadRoot, operationType: EntityActionOperation.Execute); @@ -218,7 +222,7 @@ public async Task ExecuteAsync( } catch (DataApiBuilderException dabEx) { - logger?.LogError(dabEx, "Error executing custom tool {ToolName} for entity {Entity}", toolName, _entityName); + logger?.LogError(dabEx, "Error executing custom tool {ToolName} for entity {Entity}", toolName, EntityName); return McpResponseBuilder.BuildErrorResult(toolName, "ExecutionError", dabEx.Message, logger); } catch (SqlException sqlEx) @@ -238,7 +242,7 @@ public async Task ExecuteAsync( } // 9) Build success response - return BuildExecuteSuccessResponse(toolName, _entityName, parameters, queryResult, logger); + return BuildExecuteSuccessResponse(toolName, EntityName, parameters, queryResult, logger); } catch (OperationCanceledException) { @@ -246,7 +250,7 @@ public async Task ExecuteAsync( } catch (Exception ex) { - logger?.LogError(ex, "Unexpected error in DynamicCustomTool for {EntityName}", _entityName); + logger?.LogError(ex, "Unexpected error in DynamicCustomTool for {EntityName}", EntityName); return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "An unexpected error occurred.", logger); } } diff --git a/src/Azure.DataApiBuilder.Mcp/Core/McpServerConfiguration.cs b/src/Azure.DataApiBuilder.Mcp/Core/McpServerConfiguration.cs index d76af816bd..2b48c37a83 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/McpServerConfiguration.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/McpServerConfiguration.cs @@ 
-3,9 +3,11 @@ using System.Text.Json; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; namespace Azure.DataApiBuilder.Mcp.Core { @@ -15,79 +17,84 @@ namespace Azure.DataApiBuilder.Mcp.Core internal static class McpServerConfiguration { /// - /// Configures the MCP server with tool capabilities + /// Configures the MCP server with tool capabilities. /// internal static IServiceCollection ConfigureMcpServer(this IServiceCollection services) { - services.AddMcpServer(options => + services.AddMcpServer() + .WithListToolsHandler((RequestContext request, CancellationToken ct) => { - options.ServerInfo = new() { Name = McpProtocolDefaults.MCP_SERVER_NAME, Version = McpProtocolDefaults.MCP_SERVER_VERSION }; - options.Capabilities = new() + McpToolRegistry? toolRegistry = request.Services?.GetRequiredService(); + if (toolRegistry == null) { - Tools = new() - { - ListToolsHandler = (request, ct) => - { - McpToolRegistry? toolRegistry = request.Services?.GetRequiredService(); - if (toolRegistry == null) - { - throw new InvalidOperationException("Tool registry is not available."); - } + throw new InvalidOperationException("Tool registry is not available."); + } - List tools = toolRegistry.GetAllTools().ToList(); + List tools = toolRegistry.GetAllTools().ToList(); - return ValueTask.FromResult(new ListToolsResult - { - Tools = tools - }); - }, - CallToolHandler = async (request, ct) => - { - McpToolRegistry? toolRegistry = request.Services?.GetRequiredService(); - if (toolRegistry == null) - { - throw new InvalidOperationException("Tool registry is not available."); - } - - string? 
toolName = request.Params?.Name; - if (string.IsNullOrEmpty(toolName)) - { - throw new McpException("Tool name is required."); - } + return ValueTask.FromResult(new ListToolsResult + { + Tools = tools + }); + }) + .WithCallToolHandler(async (RequestContext request, CancellationToken ct) => + { + McpToolRegistry? toolRegistry = request.Services?.GetRequiredService(); + if (toolRegistry == null) + { + throw new InvalidOperationException("Tool registry is not available."); + } - if (!toolRegistry.TryGetTool(toolName, out IMcpTool? tool)) - { - throw new McpException($"Unknown tool: '{toolName}'"); - } + string? toolName = request.Params?.Name; + if (string.IsNullOrEmpty(toolName)) + { + throw new McpException("Tool name is required."); + } - JsonDocument? arguments = null; - if (request.Params?.Arguments != null) - { - // Convert IReadOnlyDictionary to JsonDocument - Dictionary jsonObject = new(); - foreach (KeyValuePair kvp in request.Params.Arguments) - { - jsonObject[kvp.Key] = kvp.Value; - } + if (!toolRegistry.TryGetTool(toolName, out IMcpTool? tool)) + { + throw new McpException($"Unknown tool: '{toolName}'"); + } - string json = JsonSerializer.Serialize(jsonObject); - arguments = JsonDocument.Parse(json); - } + if (tool is null || request.Services is null) + { + throw new InvalidOperationException("Tool or service provider unexpectedly null."); + } - try - { - return await tool!.ExecuteAsync(arguments, request.Services!, ct); - } - finally - { - arguments?.Dispose(); - } + JsonDocument? 
arguments = null; + try + { + if (request.Params?.Arguments != null) + { + // Convert IReadOnlyDictionary to JsonDocument + Dictionary jsonObject = new(); + foreach (KeyValuePair kvp in request.Params.Arguments) + { + jsonObject[kvp.Key] = kvp.Value; } + + string json = JsonSerializer.Serialize(jsonObject); + arguments = JsonDocument.Parse(json); } - }; + + return await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, toolName, arguments, request.Services, ct); + } + finally + { + arguments?.Dispose(); + } }) .WithHttpTransport(); + // Configure underlying MCP server options + services.PostConfigure(options => + { + options.ServerInfo = new() { Name = McpProtocolDefaults.MCP_SERVER_NAME, Version = McpProtocolDefaults.MCP_SERVER_VERSION }; + options.Capabilities ??= new(); + options.Capabilities.Tools ??= new(); + }); + return services; } } diff --git a/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs b/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs index 51d8295068..1ab1c73d05 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs @@ -7,6 +7,7 @@ using Azure.DataApiBuilder.Core.AuthenticationHelpers.AuthenticationSimulator; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -284,7 +285,7 @@ private async Task HandleCallToolAsync(JsonElement? id, JsonElement root, Cancel Console.Error.WriteLine($"[MCP DEBUG] callTool → tool: {toolName}, args: "); } - // Execute the tool. + // Execute the tool with telemetry. // If a MCP stdio role override is set in the environment, create // a request HttpContext with the X-MS-API-ROLE header so tools and authorization // helpers that read IHttpContextAccessor will see the role. 
We also ensure the @@ -319,7 +320,8 @@ private async Task HandleCallToolAsync(JsonElement? id, JsonElement root, Cancel try { // Execute the tool with the scoped service provider so any scoped services resolve correctly. - callResult = await tool.ExecuteAsync(argsDoc, scopedProvider, ct); + callResult = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, toolName!, argsDoc, scopedProvider, ct); } finally { @@ -332,7 +334,8 @@ private async Task HandleCallToolAsync(JsonElement? id, JsonElement root, Cancel } else { - callResult = await tool.ExecuteAsync(argsDoc, _serviceProvider, ct); + callResult = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, toolName!, argsDoc, _serviceProvider, ct); } // Normalize to MCP content blocks (array). We try to pass through if a 'Content' property exists, diff --git a/src/Azure.DataApiBuilder.Mcp/Core/McpToolRegistry.cs b/src/Azure.DataApiBuilder.Mcp/Core/McpToolRegistry.cs index 9c9b96d72b..626ddc9125 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/McpToolRegistry.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/McpToolRegistry.cs @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Net; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Service.Exceptions; using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; namespace Azure.DataApiBuilder.Mcp.Core { @@ -11,15 +14,50 @@ namespace Azure.DataApiBuilder.Mcp.Core /// public class McpToolRegistry { - private readonly Dictionary _tools = new(); + private readonly Dictionary _tools = new(StringComparer.OrdinalIgnoreCase); /// /// Registers a tool in the registry /// + /// Thrown when tool name is invalid or duplicate public void RegisterTool(IMcpTool tool) { Tool metadata = tool.GetToolMetadata(); - _tools[metadata.Name] = tool; + string toolName = metadata.Name?.Trim() ?? 
string.Empty; + + // Reject empty or whitespace-only tool names + if (string.IsNullOrWhiteSpace(toolName)) + { + throw new DataApiBuilderException( + message: "MCP tool name cannot be null, empty, or whitespace.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + // Check for duplicate tool names (case-insensitive) + if (_tools.TryGetValue(toolName, out IMcpTool? existingTool)) + { + // If the same tool instance is already registered, skip silently. + // This can happen when both McpToolRegistryInitializer (hosted service) + // and McpStdioHelper register tools during stdio mode startup. + if (ReferenceEquals(existingTool, tool)) + { + return; + } + + string existingToolType = existingTool.ToolType == ToolType.BuiltIn ? "built-in" : "custom"; + string newToolType = tool.ToolType == ToolType.BuiltIn ? "built-in" : "custom"; + + throw new DataApiBuilderException( + message: $"Duplicate MCP tool name '{toolName}' detected. " + + $"A {existingToolType} tool with this name is already registered. " + + $"Cannot register {newToolType} tool with the same name. 
" + + $"Tool names must be unique across all tool types.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + _tools[toolName] = tool; } /// diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs index 49cacef2c3..401f270f42 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs @@ -34,7 +34,7 @@ public static CallToolResult BuildSuccessResult( { Content = new List { - new TextContentBlock { Type = "text", Text = output } + new TextContentBlock { Text = output } } }; } @@ -67,7 +67,7 @@ public static CallToolResult BuildErrorResult( { Content = new List { - new TextContentBlock { Type = "text", Text = output } + new TextContentBlock { Text = output } }, IsError = true }; diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs new file mode 100644 index 0000000000..f69a26fa5d --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Constants for MCP telemetry error codes. + /// + internal static class McpTelemetryErrorCodes + { + /// + /// Generic execution failure error code. + /// + public const string EXECUTION_FAILED = "ExecutionFailed"; + + /// + /// Authentication failure error code. + /// + public const string AUTHENTICATION_FAILED = "AuthenticationFailed"; + + /// + /// Authorization failure error code. + /// + public const string AUTHORIZATION_FAILED = "AuthorizationFailed"; + + /// + /// Database operation failure error code. + /// + public const string DATABASE_ERROR = "DatabaseError"; + + /// + /// Invalid request or arguments error code. 
+ /// + public const string INVALID_REQUEST = "InvalidRequest"; + + /// + /// Operation cancelled error code. + /// + public const string OPERATION_CANCELLED = "OperationCancelled"; + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs new file mode 100644 index 0000000000..2a60557f8d --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Text.Json; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Telemetry; +using Azure.DataApiBuilder.Mcp.Core; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Utility class for MCP telemetry operations. + /// + internal static class McpTelemetryHelper + { + /// + /// Executes an MCP tool wrapped in an OpenTelemetry activity span. + /// Handles telemetry attribute extraction, success/failure tracking, + /// and exception recording with typed error codes. + /// + /// The MCP tool to execute. + /// The name of the tool being invoked. + /// The parsed JSON arguments for the tool (may be null). + /// The service provider for resolving dependencies. + /// Cancellation token. + /// The result of the tool execution. + public static async Task ExecuteWithTelemetryAsync( + IMcpTool tool, + string toolName, + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken) + { + using Activity? activity = TelemetryTracesHelper.DABActivitySource.StartActivity("mcp.tool.execute"); + + try + { + // Extract telemetry metadata + string? 
entityName = ExtractEntityNameFromArguments(arguments); + string? operation = InferOperationFromTool(tool, toolName); + string? dbProcedure = null; + + // For custom tools (DynamicCustomTool), extract stored procedure information + if (tool is DynamicCustomTool customTool) + { + (entityName, dbProcedure) = ExtractCustomToolMetadata(customTool, serviceProvider); + } + + // Track the start of MCP tool execution with telemetry + activity?.TrackMcpToolExecutionStarted( + toolName: toolName, + entityName: entityName, + operation: operation, + dbProcedure: dbProcedure); + + // Execute the tool + CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, cancellationToken); + + // Check if the tool returned an error result (tools catch exceptions internally + // and return CallToolResult with IsError=true instead of throwing) + if (result.IsError == true) + { + // Extract error code and message from the result content + (string? errorCode, string? errorMessage) = ExtractErrorFromCallToolResult(result); + activity?.SetStatus(ActivityStatusCode.Error, errorMessage ?? "Tool returned an error result"); + activity?.SetTag("mcp.tool.error", true); + + if (!string.IsNullOrEmpty(errorCode)) + { + activity?.SetTag("error.code", errorCode); + } + + if (!string.IsNullOrEmpty(errorMessage)) + { + activity?.SetTag("error.message", errorMessage); + } + } + else + { + // Track successful completion + activity?.TrackMcpToolExecutionFinished(); + } + + return result; + } + catch (Exception ex) + { + // Track exception in telemetry with specific error code based on exception type + string errorCode = MapExceptionToErrorCode(ex); + activity?.TrackMcpToolExecutionFinishedWithException(ex, errorCode: errorCode); + throw; + } + } + + /// + /// Infers the operation type from the tool instance and name. + /// For built-in tools, maps tool name directly to operation. + /// For custom tools (stored procedures), always returns "execute". + /// + /// The tool instance. 
+ /// The name of the tool. + /// The inferred operation type. + public static string InferOperationFromTool(IMcpTool tool, string toolName) + { + // Custom tools (stored procedures) are always "execute" + if (tool.ToolType == ToolType.Custom) + { + return "execute"; + } + + // Built-in tools: map tool name to operation + return toolName.ToLowerInvariant() switch + { + "read_records" => "read", + "create_record" => "create", + "update_record" => "update", + "delete_record" => "delete", + "describe_entities" => "describe", + "execute_entity" => "execute", + _ => "execute" // Fallback for any unknown built-in tools + }; + } + + /// + /// Extracts error code and message from a CallToolResult's content. + /// MCP tools may return errors as JSON with "code" and "message" properties. + /// + /// The tool result to extract error info from. + /// A tuple of (errorCode, errorMessage). + private static (string? errorCode, string? errorMessage) ExtractErrorFromCallToolResult(CallToolResult result) + { + string? errorCode = null; + string? errorMessage = null; + + if (result.Content != null) + { + foreach (ContentBlock block in result.Content) + { + // Check if this is a text block with JSON error information + if (block is TextContentBlock textBlock && !string.IsNullOrEmpty(textBlock.Text)) + { + try + { + using JsonDocument doc = JsonDocument.Parse(textBlock.Text); + JsonElement root = doc.RootElement; + + if (root.TryGetProperty("code", out JsonElement codeEl)) + { + errorCode = codeEl.GetString(); + } + + if (root.TryGetProperty("message", out JsonElement msgEl)) + { + errorMessage = msgEl.GetString(); + } + + // If we found error info, we can break + if (errorCode != null || errorMessage != null) + { + break; + } + } + catch + { + // Not JSON or doesn't have expected structure, skip + } + } + } + } + + return (errorCode, errorMessage); + } + + /// + /// Maps an exception to a telemetry error code. + /// + /// The exception to map. 
+ /// The corresponding error code string. + public static string MapExceptionToErrorCode(Exception ex) + { + return ex switch + { + OperationCanceledException => McpTelemetryErrorCodes.OPERATION_CANCELLED, + DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthenticationChallenge + => McpTelemetryErrorCodes.AUTHENTICATION_FAILED, + DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthorizationCheckFailed + => McpTelemetryErrorCodes.AUTHORIZATION_FAILED, + UnauthorizedAccessException => McpTelemetryErrorCodes.AUTHORIZATION_FAILED, + System.Data.Common.DbException => McpTelemetryErrorCodes.DATABASE_ERROR, + ArgumentException => McpTelemetryErrorCodes.INVALID_REQUEST, + _ => McpTelemetryErrorCodes.EXECUTION_FAILED + }; + } + + /// + /// Extracts the entity name from parsed tool arguments, if present. + /// + /// The parsed JSON arguments. + /// The entity name, or null if not present. + private static string? ExtractEntityNameFromArguments(JsonDocument? arguments) + { + if (arguments != null && + arguments.RootElement.TryGetProperty("entity", out JsonElement entityEl) && + entityEl.ValueKind == JsonValueKind.String) + { + return entityEl.GetString(); + } + + return null; + } + + /// + /// Extracts metadata from a custom tool for telemetry purposes. + /// Returns best-effort metadata; failures in configuration access must not prevent tool execution. + /// + /// The custom tool instance. + /// The service provider. + /// A tuple containing the entity name and database procedure name. + public static (string? entityName, string? dbProcedure) ExtractCustomToolMetadata(DynamicCustomTool customTool, IServiceProvider serviceProvider) + { + // Access public properties instead of reflection + string? 
entityName = customTool.EntityName; + + if (entityName == null) + { + return (null, null); + } + + try + { + // Try to get the stored procedure name from the runtime configuration + RuntimeConfigProvider? runtimeConfigProvider = serviceProvider.GetService(); + if (runtimeConfigProvider != null) + { + RuntimeConfig config = runtimeConfigProvider.GetConfig(); + if (config.Entities.TryGetValue(entityName, out Entity? entityConfig)) + { + string? dbProcedure = entityConfig.Source.Object; + return (entityName, dbProcedure); + } + } + } + catch (Exception) + { + // If configuration access fails for any reason (including DataApiBuilderException + // when runtime config isn't set up), fall back to returning only the entity name. + // Telemetry metadata extraction is best-effort and must not prevent tool execution. + } + + return (entityName, null); + } + } +} diff --git a/src/Cli.Tests/AutoConfigSimulateTests.cs b/src/Cli.Tests/AutoConfigSimulateTests.cs new file mode 100644 index 0000000000..ef90a587dd --- /dev/null +++ b/src/Cli.Tests/AutoConfigSimulateTests.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Cli.Tests; + +/// +/// Tests for the auto-config-simulate CLI command. +/// +[TestClass] +public class AutoConfigSimulateTests +{ + /// + /// MSSQL test category constant, matching the value used by Service.Tests to filter integration tests. + /// Run with: dotnet test --filter "TestCategory=MsSql" + /// + private const string MSSQL_CATEGORY = "MsSql"; + + /// + /// Connection string template for integration tests. + /// The @env('MSSQL_SA_PASSWORD') reference is resolved at config load time when + /// TrySimulateAutoentities calls TryLoadConfig with doReplaceEnvVar: true. 
+ /// + private const string MSSQL_CONNECTION_STRING_TEMPLATE = + "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;" + + "Password=@env('MSSQL_SA_PASSWORD');MultipleActiveResultSets=False;Connection Timeout=30;"; + + private IFileSystem? _fileSystem; + private FileSystemRuntimeConfigLoader? _runtimeConfigLoader; + + [TestInitialize] + public void TestInitialize() + { + _fileSystem = FileSystemUtils.ProvisionMockFileSystem(); + _runtimeConfigLoader = new FileSystemRuntimeConfigLoader(_fileSystem); + + ILoggerFactory loggerFactory = TestLoggerSupport.ProvisionLoggerFactory(); + ConfigGenerator.SetLoggerForCliConfigGenerator(loggerFactory.CreateLogger()); + SetCliUtilsLogger(loggerFactory.CreateLogger()); + } + + [TestCleanup] + public void TestCleanup() + { + _fileSystem = null; + _runtimeConfigLoader = null; + } + + /// + /// Tests that the simulate command fails when no autoentities are defined in the config. + /// + [TestMethod] + public void TestSimulateAutoentities_NoAutoentitiesDefined() + { + // Arrange: create an MSSQL config without autoentities + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigSimulateOptions options = new(config: TEST_RUNTIME_CONFIG_FILE); + + // Act + bool success = TrySimulateAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsFalse(success); + } + + /// + /// Integration test: verifies that an autoentities filter matching a known table (dbo.books) + /// produces correct console output containing the filter name, entity name, and database object. + /// Requires a running MSSQL instance with MSSQL_SA_PASSWORD environment variable set. 
+ /// + [TestMethod] + [TestCategory(MSSQL_CATEGORY)] + public void TestSimulateAutoentities_WithMatchingFilter_OutputsToConsole() + { + if (string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MSSQL_SA_PASSWORD"))) + { + Assert.Inconclusive("MSSQL_SA_PASSWORD environment variable not set. Skipping integration test."); + return; + } + + // Arrange: create MSSQL config with autoentities filter for dbo.books + InitOptions initOptions = new( + databaseType: DatabaseType.MSSQL, + connectionString: MSSQL_CONNECTION_STRING_TEMPLATE, + cosmosNoSqlDatabase: null, + cosmosNoSqlContainer: null, + graphQLSchemaPath: null, + setSessionContext: false, + hostMode: HostMode.Development, + corsOrigin: new List(), + authenticationProvider: EasyAuthType.AppService.ToString(), + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions autoConfigOptions = new( + definitionName: "books-filter", + patternsInclude: new[] { "dbo.books" }, + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryConfigureAutoentities(autoConfigOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigSimulateOptions options = new(config: TEST_RUNTIME_CONFIG_FILE); + + // Capture console output + TextWriter originalOut = Console.Out; + using StringWriter consoleOutput = new(); + Console.SetOut(consoleOutput); + bool success; + try + { + success = TrySimulateAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + } + finally + { + Console.SetOut(originalOut); + } + + string output = consoleOutput.ToString(); + + // Assert + Assert.IsTrue(success, "Simulation should succeed when the filter matches tables."); + StringAssert.Contains(output, "books-filter", "Output should contain the filter name."); + StringAssert.Contains(output, "books", "Output should contain the entity name."); + StringAssert.Contains(output, "dbo.books", "Output should contain the database object."); + } + + /// + /// 
Integration test: verifies that an autoentities filter matching a known table (dbo.books) + /// produces a well-formed CSV file containing the filter name, entity name, and database object. + /// Requires a running MSSQL instance with MSSQL_SA_PASSWORD environment variable set. + /// + [TestMethod] + [TestCategory(MSSQL_CATEGORY)] + public void TestSimulateAutoentities_WithMatchingFilter_WritesToCsvFile() + { + if (string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MSSQL_SA_PASSWORD"))) + { + Assert.Inconclusive("MSSQL_SA_PASSWORD environment variable not set. Skipping integration test."); + return; + } + + // Arrange: create MSSQL config with autoentities filter for dbo.books + InitOptions initOptions = new( + databaseType: DatabaseType.MSSQL, + connectionString: MSSQL_CONNECTION_STRING_TEMPLATE, + cosmosNoSqlDatabase: null, + cosmosNoSqlContainer: null, + graphQLSchemaPath: null, + setSessionContext: false, + hostMode: HostMode.Development, + corsOrigin: new List(), + authenticationProvider: EasyAuthType.AppService.ToString(), + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions autoConfigOptions = new( + definitionName: "books-filter", + patternsInclude: new[] { "dbo.books" }, + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryConfigureAutoentities(autoConfigOptions, _runtimeConfigLoader!, _fileSystem!)); + + string outputCsvPath = "simulation-output.csv"; + AutoConfigSimulateOptions options = new(output: outputCsvPath, config: TEST_RUNTIME_CONFIG_FILE); + + // Act + bool success = TrySimulateAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success, "Simulation should succeed when the filter matches tables."); + Assert.IsTrue(_fileSystem!.File.Exists(outputCsvPath), "CSV output file should be created."); + string csvContent = _fileSystem.File.ReadAllText(outputCsvPath); + 
StringAssert.Contains(csvContent, "filter_name,entity_name,database_object", "CSV should have a header row."); + StringAssert.Contains(csvContent, "books-filter", "CSV should contain the filter name."); + StringAssert.Contains(csvContent, "books", "CSV should contain the entity name."); + StringAssert.Contains(csvContent, "dbo.books", "CSV should contain the database object."); + } + + /// + /// Integration test: verifies that an autoentities filter matching no tables returns success + /// and prints a "(no matches)" message to the console. + /// Requires a running MSSQL instance with MSSQL_SA_PASSWORD environment variable set. + /// + [TestMethod] + [TestCategory(MSSQL_CATEGORY)] + public void TestSimulateAutoentities_WithNonMatchingFilter_OutputsNoMatches() + { + if (string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MSSQL_SA_PASSWORD"))) + { + Assert.Inconclusive("MSSQL_SA_PASSWORD environment variable not set. Skipping integration test."); + return; + } + + // Arrange: create MSSQL config with autoentities filter that matches no tables + InitOptions initOptions = new( + databaseType: DatabaseType.MSSQL, + connectionString: MSSQL_CONNECTION_STRING_TEMPLATE, + cosmosNoSqlDatabase: null, + cosmosNoSqlContainer: null, + graphQLSchemaPath: null, + setSessionContext: false, + hostMode: HostMode.Development, + corsOrigin: new List(), + authenticationProvider: EasyAuthType.AppService.ToString(), + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions autoConfigOptions = new( + definitionName: "empty-filter", + patternsInclude: new[] { "dbo.NonExistentTable99999" }, + config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryConfigureAutoentities(autoConfigOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigSimulateOptions options = new(config: TEST_RUNTIME_CONFIG_FILE); + + // Capture console output + TextWriter originalOut = Console.Out; + using 
StringWriter consoleOutput = new(); + Console.SetOut(consoleOutput); + bool success; + try + { + success = TrySimulateAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + } + finally + { + Console.SetOut(originalOut); + } + + string output = consoleOutput.ToString(); + + // Assert + // Output format is produced by WriteSimulationResultsToConsole: + // "Filter: ", "Matches: ", and "(no matches)" when count is 0. + Assert.IsTrue(success, "Simulation should succeed even when no tables match."); + StringAssert.Contains(output, "empty-filter", "Output should contain the filter name."); + StringAssert.Contains(output, "Matches: 0", "Output should show zero matches."); + StringAssert.Contains(output, "(no matches)", "Output should show the 'no matches' message."); + } +} diff --git a/src/Cli.Tests/AutoConfigTests.cs b/src/Cli.Tests/AutoConfigTests.cs new file mode 100644 index 0000000000..40e3a461f7 --- /dev/null +++ b/src/Cli.Tests/AutoConfigTests.cs @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Cli.Tests; + +/// +/// Tests for the auto-config CLI command. +/// +[TestClass] +public class AutoConfigTests +{ + private IFileSystem? _fileSystem; + private FileSystemRuntimeConfigLoader? _runtimeConfigLoader; + + [TestInitialize] + public void TestInitialize() + { + _fileSystem = FileSystemUtils.ProvisionMockFileSystem(); + _runtimeConfigLoader = new FileSystemRuntimeConfigLoader(_fileSystem); + + ILoggerFactory loggerFactory = TestLoggerSupport.ProvisionLoggerFactory(); + ConfigGenerator.SetLoggerForCliConfigGenerator(loggerFactory.CreateLogger()); + SetCliUtilsLogger(loggerFactory.CreateLogger()); + } + + [TestCleanup] + public void TestCleanup() + { + _fileSystem = null; + _runtimeConfigLoader = null; + } + + /// + /// Tests that a new autoentities definition is successfully created with patterns. 
+ /// + [TestMethod] + public void TestCreateAutoentitiesDefinition_WithPatterns() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions options = new( + definitionName: "test-def", + patternsInclude: new[] { "dbo.%", "sys.%" }, + patternsExclude: new[] { "dbo.internal%" }, + patternsName: "{schema}_{table}", + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success); + Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig(TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? config)); + Assert.IsNotNull(config.Autoentities); + Assert.IsTrue(config.Autoentities.Autoentities.ContainsKey("test-def")); + + Autoentity autoentity = config.Autoentities.Autoentities["test-def"]; + Assert.AreEqual(2, autoentity.Patterns.Include.Length); + Assert.AreEqual("dbo.%", autoentity.Patterns.Include[0]); + Assert.AreEqual("sys.%", autoentity.Patterns.Include[1]); + Assert.AreEqual(1, autoentity.Patterns.Exclude.Length); + Assert.AreEqual("dbo.internal%", autoentity.Patterns.Exclude[0]); + Assert.AreEqual("{schema}_{table}", autoentity.Patterns.Name); + } + + /// + /// Tests that template options are correctly configured for an autoentities definition. 
+ /// + [TestMethod] + public void TestConfigureAutoentitiesDefinition_WithTemplateOptions() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions options = new( + definitionName: "test-def", + templateRestEnabled: true, + templateGraphqlEnabled: false, + templateMcpDmlTool: "true", + templateCacheEnabled: true, + templateCacheTtlSeconds: 30, + templateCacheLevel: "L1", + templateHealthEnabled: true, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success); + Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig(TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? config)); + + Autoentity autoentity = config.Autoentities!.Autoentities["test-def"]; + Assert.IsTrue(autoentity.Template.Rest.Enabled); + Assert.IsFalse(autoentity.Template.GraphQL.Enabled); + Assert.IsTrue(autoentity.Template.Mcp!.DmlToolEnabled); + Assert.AreEqual(true, autoentity.Template.Cache.Enabled); + Assert.AreEqual(30, autoentity.Template.Cache.TtlSeconds); + Assert.AreEqual(EntityCacheLevel.L1, autoentity.Template.Cache.Level); + Assert.IsTrue(autoentity.Template.Health.Enabled); + } + + /// + /// Tests that an existing autoentities definition is successfully updated. 
+ /// + [TestMethod] + public void TestUpdateExistingAutoentitiesDefinition() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + // Create initial definition + AutoConfigOptions initialOptions = new( + definitionName: "test-def", + patternsInclude: new[] { "dbo.%" }, + templateCacheTtlSeconds: 10, + permissions: new[] { "anonymous", "read" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + Assert.IsTrue(ConfigGenerator.TryConfigureAutoentities(initialOptions, _runtimeConfigLoader!, _fileSystem!)); + + // Update definition + AutoConfigOptions updateOptions = new( + definitionName: "test-def", + patternsExclude: new[] { "dbo.internal%" }, + templateCacheTtlSeconds: 60, + permissions: new[] { "authenticated", "create,read,update,delete" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(updateOptions, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success); + Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig(TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? config)); + + Autoentity autoentity = config.Autoentities!.Autoentities["test-def"]; + // Include should remain from initial setup + Assert.AreEqual(1, autoentity.Patterns.Include.Length); + Assert.AreEqual("dbo.%", autoentity.Patterns.Include[0]); + // Exclude should be added + Assert.AreEqual(1, autoentity.Patterns.Exclude.Length); + Assert.AreEqual("dbo.internal%", autoentity.Patterns.Exclude[0]); + // Cache TTL should be updated + Assert.AreEqual(60, autoentity.Template.Cache.TtlSeconds); + // Permissions should be replaced + Assert.AreEqual(1, autoentity.Permissions.Length); + Assert.AreEqual("authenticated", autoentity.Permissions[0].Role); + } + + /// + /// Tests that permissions are correctly parsed and applied. 
+ /// + [TestMethod] + public void TestConfigureAutoentitiesDefinition_WithMultipleActions() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions options = new( + definitionName: "test-def", + permissions: new[] { "authenticated", "create,read,update,delete" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success); + Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig(TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? config)); + + Autoentity autoentity = config.Autoentities!.Autoentities["test-def"]; + Assert.AreEqual(1, autoentity.Permissions.Length); + Assert.AreEqual("authenticated", autoentity.Permissions[0].Role); + Assert.AreEqual(4, autoentity.Permissions[0].Actions.Length); + } + + /// + /// Tests that invalid MCP dml-tool value is handled correctly. + /// + [TestMethod] + public void TestConfigureAutoentitiesDefinition_InvalidMcpDmlTool() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions options = new( + definitionName: "test-def", + templateMcpDmlTool: "invalid-value", + permissions: new[] { "anonymous", "read" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert - Should fail due to invalid MCP value + Assert.IsFalse(success); + } + + /// + /// Tests that invalid cache level value is handled correctly. 
+ /// + [TestMethod] + public void TestConfigureAutoentitiesDefinition_InvalidCacheLevel() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + AutoConfigOptions options = new( + definitionName: "test-def", + templateCacheLevel: "InvalidLevel", + permissions: new[] { "anonymous", "read" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert - Should fail due to invalid cache level + Assert.IsFalse(success); + } + + /// + /// Tests that multiple autoentities definitions can coexist. + /// + [TestMethod] + public void TestMultipleAutoentitiesDefinitions() + { + // Arrange + InitOptions initOptions = CreateBasicInitOptionsForMsSqlWithConfig(config: TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(ConfigGenerator.TryGenerateConfig(initOptions, _runtimeConfigLoader!, _fileSystem!)); + + // Create first definition + AutoConfigOptions options1 = new( + definitionName: "def-1", + patternsInclude: new[] { "dbo.%" }, + permissions: new[] { "anonymous", "read" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + Assert.IsTrue(ConfigGenerator.TryConfigureAutoentities(options1, _runtimeConfigLoader!, _fileSystem!)); + + // Create second definition + AutoConfigOptions options2 = new( + definitionName: "def-2", + patternsInclude: new[] { "sys.%" }, + permissions: new[] { "authenticated", "*" }, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options2, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(success); + Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig(TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? 
config)); + Assert.AreEqual(2, config.Autoentities!.Autoentities.Count); + Assert.IsTrue(config.Autoentities.Autoentities.ContainsKey("def-1")); + Assert.IsTrue(config.Autoentities.Autoentities.ContainsKey("def-2")); + } + + /// + /// Tests that attempting to configure autoentities without a config file fails. + /// + [TestMethod] + public void TestConfigureAutoentitiesDefinition_NoConfigFile() + { + // Arrange + AutoConfigOptions options = new( + definitionName: "test-def", + permissions: new[] { "anonymous", "read" } + ); + + // Act + bool success = ConfigGenerator.TryConfigureAutoentities(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsFalse(success); + } +} diff --git a/src/Cli.Tests/ConfigGeneratorTests.cs b/src/Cli.Tests/ConfigGeneratorTests.cs index 59a7f7b8dd..9621a8ce98 100644 --- a/src/Cli.Tests/ConfigGeneratorTests.cs +++ b/src/Cli.Tests/ConfigGeneratorTests.cs @@ -178,6 +178,7 @@ public void TestSpecialCharactersInConnectionString() ""mode"": ""production"" } }, + ""autoentities"": {}, ""entities"": {} }"); diff --git a/src/Cli.Tests/ConfigureOptionsTests.cs b/src/Cli.Tests/ConfigureOptionsTests.cs index 4dad501fda..a1355679fb 100644 --- a/src/Cli.Tests/ConfigureOptionsTests.cs +++ b/src/Cli.Tests/ConfigureOptionsTests.cs @@ -14,6 +14,7 @@ public class ConfigureOptionsTests : VerifyBase private MockFileSystem? _fileSystem; private FileSystemRuntimeConfigLoader? _runtimeConfigLoader; private const string TEST_RUNTIME_CONFIG_FILE = "test-update-runtime-setting.json"; + private const string TEST_DATASOURCE_HEALTH_NAME = "My Data Source"; [TestInitialize] public void TestInitialize() @@ -540,6 +541,34 @@ public void TestUpdateTTLForCacheSettings(int updatedTtlValue) Assert.AreEqual(updatedTtlValue, runtimeConfig.Runtime.Cache.TtlSeconds); } + /// + /// Tests that running "dab configure --runtime.compression.level {value}" on a config with various values results + /// in runtime config update. 
Takes in updated value for compression.level and + /// validates whether the runtime config reflects those updated values. + [DataTestMethod] + [DataRow(CompressionLevel.Fastest, DisplayName = "Update Compression.Level to fastest.")] + [DataRow(CompressionLevel.Optimal, DisplayName = "Update Compression.Level to optimal.")] + [DataRow(CompressionLevel.None, DisplayName = "Update Compression.Level to none.")] + public void TestUpdateLevelForCompressionSettings(CompressionLevel updatedLevelValue) + { + // Arrange -> all the setup which includes creating options. + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + + // Act: Attempts to update compression level value + ConfigureOptions options = new( + runtimeCompressionLevel: updatedLevelValue, + config: TEST_RUNTIME_CONFIG_FILE + ); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Validate the Level Value is updated + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); + Assert.IsNotNull(runtimeConfig.Runtime?.Compression?.Level); + Assert.AreEqual(updatedLevelValue, runtimeConfig.Runtime.Compression.Level); + } + /// /// Tests that running "dab configure --runtime.host.mode {value}" on a config with various values results /// in runtime config update. Takes in updated value for host.mode and @@ -927,6 +956,102 @@ public void TestFailureWhenAddingSetSessionContextToMySQLDatabase() } /// + /// Tests adding data-source.health.name to a config that doesn't have a health section. + /// This method verifies that the health.name can be added to a data source configuration + /// that doesn't previously have a health section. 
+ /// Command: dab configure --data-source.health.name "My Data Source" + /// + [TestMethod] + public void TestAddDataSourceHealthName() + { + // Arrange + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + + ConfigureOptions options = new( + dataSourceHealthName: TEST_DATASOURCE_HEALTH_NAME, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? config)); + Assert.IsNotNull(config.DataSource); + Assert.IsNotNull(config.DataSource.Health); + Assert.AreEqual(TEST_DATASOURCE_HEALTH_NAME, config.DataSource.Health.Name); + Assert.IsTrue(config.DataSource.Health.Enabled); // Default value + } + + /// + /// Tests updating data-source.health.name on a config that already has a health section. + /// This method verifies that the health.name can be updated while preserving other health settings. 
+ /// Command: dab configure --data-source.health.name "Updated Name" + /// + [DataTestMethod] + [DataRow("New Name", DisplayName = "Update health name with a simple string")] + [DataRow("This is the value", DisplayName = "Update health name with the example from the issue")] + public void TestUpdateDataSourceHealthName(string healthName) + { + // Arrange - Config with existing health section + string configWithHealth = @" + { + ""$schema"": ""test"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""testconnectionstring"", + ""health"": { + ""enabled"": false, + ""threshold-ms"": 2000 + } + }, + ""runtime"": { + ""rest"": { + ""enabled"": true, + ""path"": ""/api"" + }, + ""graphql"": { + ""enabled"": true, + ""path"": ""/graphql"", + ""allow-introspection"": true + }, + ""host"": { + ""mode"": ""development"", + ""cors"": { + ""origins"": [], + ""allow-credentials"": false + }, + ""authentication"": { + ""provider"": ""StaticWebApps"" + } + } + }, + ""entities"": {} + }"; + SetupFileSystemWithInitialConfig(configWithHealth); + + ConfigureOptions options = new( + dataSourceHealthName: healthName, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? config)); + Assert.IsNotNull(config.DataSource); + Assert.IsNotNull(config.DataSource.Health); + Assert.AreEqual(healthName, config.DataSource.Health.Name); + // Verify existing health settings are preserved + Assert.IsFalse(config.DataSource.Health.Enabled); + Assert.AreEqual(2000, config.DataSource.Health.ThresholdMs); + } + /// Tests that running "dab configure --runtime.mcp.description {value}" on a config with various values results /// in runtime config update. 
Takes in updated value for mcp.description and /// validates whether the runtime config reflects those updated values @@ -969,5 +1094,130 @@ private void SetupFileSystemWithInitialConfig(string jsonConfig) Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig? config)); Assert.IsNotNull(config.Runtime); } + + /// + /// Tests adding user-delegated-auth configuration options individually or together. + /// Verifies that enabled and database-audience properties can be set independently or combined. + /// Also verifies default values for properties not explicitly set. + /// Commands: + /// - dab configure --data-source.user-delegated-auth.enabled true + /// - dab configure --data-source.user-delegated-auth.database-audience "https://database.windows.net" + /// - dab configure --data-source.user-delegated-auth.enabled true --data-source.user-delegated-auth.database-audience "https://database.windows.net" + /// + [DataTestMethod] + [DataRow(true, null, DisplayName = "Set enabled=true only")] + [DataRow(null, "https://database.windows.net", DisplayName = "Set database-audience only")] + [DataRow(true, "https://database.windows.net", DisplayName = "Set both enabled and database-audience")] + public void TestAddUserDelegatedAuthConfiguration(bool? enabledValue, string? audienceValue) + { + // Arrange + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + + ConfigureOptions options = new( + dataSourceUserDelegatedAuthEnabled: enabledValue, + dataSourceUserDelegatedAuthDatabaseAudience: audienceValue, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? 
config)); + Assert.IsNotNull(config.DataSource); + Assert.IsNotNull(config.DataSource.UserDelegatedAuth); + + // Verify enabled value (if set, use provided value; otherwise defaults to false) + if (enabledValue.HasValue) + { + Assert.AreEqual(enabledValue.Value, config.DataSource.UserDelegatedAuth.Enabled); + } + else + { + Assert.IsFalse(config.DataSource.UserDelegatedAuth.Enabled); + } + + // Verify database-audience value + if (audienceValue is not null) + { + Assert.AreEqual(audienceValue, config.DataSource.UserDelegatedAuth.DatabaseAudience); + } + else + { + Assert.IsNull(config.DataSource.UserDelegatedAuth.DatabaseAudience); + } + + // Verify provider is set to default + Assert.AreEqual("EntraId", config.DataSource.UserDelegatedAuth.Provider); + } + + /// + /// Tests that enabling user-delegated-auth on a non-MSSQL database fails. + /// This method verifies that user-delegated-auth is only allowed for MSSQL database type. + /// Command: dab configure --data-source.database-type postgresql --data-source.user-delegated-auth.enabled true + /// + [DataTestMethod] + [DataRow("postgresql", DisplayName = "Fail when enabling user-delegated-auth on PostgreSQL")] + [DataRow("mysql", DisplayName = "Fail when enabling user-delegated-auth on MySQL")] + [DataRow("cosmosdb_nosql", DisplayName = "Fail when enabling user-delegated-auth on CosmosDB")] + public void TestFailureWhenEnablingUserDelegatedAuthOnNonMSSQLDatabase(string dbType) + { + // Arrange + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + + ConfigureOptions options = new( + dataSourceDatabaseType: dbType, + dataSourceUserDelegatedAuthEnabled: true, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsFalse(isSuccess); + } + + /// + /// Tests updating existing user-delegated-auth configuration by changing the database-audience. 
+ /// Verifies that the database-audience can be updated while preserving the enabled setting. + /// Also validates JSON structure: verifies user-delegated-auth is correctly nested under data-source + /// with proper JSON property names (enabled, provider, database-audience). + /// + [TestMethod] + public void TestUpdateUserDelegatedAuthDatabaseAudience() + { + // Arrange - Config with existing user-delegated-auth section + SetupFileSystemWithInitialConfig(TestHelper.CONFIG_WITH_USER_DELEGATED_AUTH); + + string newAudience = "https://database.usgovcloudapi.net"; + ConfigureOptions options = new( + dataSourceUserDelegatedAuthDatabaseAudience: newAudience, + config: TEST_RUNTIME_CONFIG_FILE + ); + + // Act + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? config)); + Assert.IsNotNull(config.DataSource); + Assert.IsNotNull(config.DataSource.UserDelegatedAuth); + Assert.IsTrue(config.DataSource.UserDelegatedAuth.Enabled); + Assert.AreEqual(newAudience, config.DataSource.UserDelegatedAuth.DatabaseAudience); + Assert.AreEqual("EntraId", config.DataSource.UserDelegatedAuth.Provider); + + // Verify JSON structure using JObject to ensure correct nesting + JObject configJson = JObject.Parse(updatedConfig); + JToken? 
userDelegatedAuthSection = configJson["data-source"]?["user-delegated-auth"]; + Assert.IsNotNull(userDelegatedAuthSection); + Assert.AreEqual(newAudience, (string?)userDelegatedAuthSection["database-audience"]); + Assert.AreEqual(true, (bool?)userDelegatedAuthSection["enabled"]); + Assert.AreEqual("EntraId", (string?)userDelegatedAuthSection["provider"]); + } } } diff --git a/src/Cli.Tests/ModuleInitializer.cs b/src/Cli.Tests/ModuleInitializer.cs index a03dcddd10..2baf75a780 100644 --- a/src/Cli.Tests/ModuleInitializer.cs +++ b/src/Cli.Tests/ModuleInitializer.cs @@ -23,6 +23,8 @@ public static void Init() VerifierSettings.IgnoreMember(dataSource => dataSource.IsDatasourceHealthEnabled); // Ignore the DatasourceThresholdMs from the output to avoid committing it. VerifierSettings.IgnoreMember(dataSource => dataSource.DatasourceThresholdMs); + // Ignore the IsUserDelegatedAuthEnabled from the output as it's a computed property. + VerifierSettings.IgnoreMember(dataSource => dataSource.IsUserDelegatedAuthEnabled); // Ignore the datasource files as that's unimportant from a test standpoint. VerifierSettings.IgnoreMember(config => config.DataSourceFiles); // Ignore the CosmosDataSourceUsed as that's unimportant from a test standpoint. @@ -105,6 +107,8 @@ public static void Init() VerifierSettings.IgnoreMember(options => options.FeatureFlags); // Ignore the JSON schema path as that's unimportant from a test standpoint. VerifierSettings.IgnoreMember(config => config.Schema); + // Ignore the Autoentities section from the output as that's unimportant from a test standpoint. + VerifierSettings.IgnoreMember(config => config.Autoentities); // Ignore the message as that's not serialized in our config file anyway. VerifierSettings.IgnoreMember(dataSource => dataSource.DatabaseTypeNotSupportedMessage); // Ignore DefaultDataSourceName as that's not serialized in our config file. 
diff --git a/src/Cli.Tests/TestHelper.cs b/src/Cli.Tests/TestHelper.cs index 8224a079d4..a75e359ee4 100644 --- a/src/Cli.Tests/TestHelper.cs +++ b/src/Cli.Tests/TestHelper.cs @@ -279,6 +279,46 @@ public static Process ExecuteDabCommand(string command, string flags) public const string CONFIG_WITH_DISABLED_GLOBAL_REST_GRAPHQL = $"{{{SAMPLE_SCHEMA_DATA_SOURCE},{RUNTIME_SECTION_WITH_DISABLED_REST_GRAPHQL}}}"; + /// + /// A config json with user-delegated-auth enabled. This is used in tests to verify updating existing + /// user-delegated-auth configuration. + /// + public const string CONFIG_WITH_USER_DELEGATED_AUTH = @" + { + ""$schema"": """ + DAB_DRAFT_SCHEMA_TEST_PATH + @""", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": """ + SAMPLE_TEST_CONN_STRING + @""", + ""user-delegated-auth"": { + ""enabled"": true, + ""provider"": ""EntraId"", + ""database-audience"": ""https://database.windows.net"" + } + }, + ""runtime"": { + ""rest"": { + ""enabled"": true, + ""path"": ""/api"" + }, + ""graphql"": { + ""enabled"": true, + ""path"": ""/graphql"", + ""allow-introspection"": true + }, + ""host"": { + ""mode"": ""development"", + ""cors"": { + ""origins"": [], + ""allow-credentials"": false + }, + ""authentication"": { + ""provider"": ""StaticWebApps"" + } + } + }, + ""entities"": {} + }"; + public const string SINGLE_ENTITY = @" { ""entities"": { @@ -1302,6 +1342,7 @@ public static string GenerateConfigWithGivenDepthLimit(string? depthLimitJson = }} }} }}, + ""autoentities"": {{}}, ""entities"": {{}}"; return $"{{{SAMPLE_SCHEMA_DATA_SOURCE},{runtimeSection}}}"; diff --git a/src/Cli.Tests/UserDelegatedAuthRuntimeParsingTests.cs b/src/Cli.Tests/UserDelegatedAuthRuntimeParsingTests.cs new file mode 100644 index 0000000000..29110a5a7c --- /dev/null +++ b/src/Cli.Tests/UserDelegatedAuthRuntimeParsingTests.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +namespace Cli.Tests +{ + [TestClass] + public class UserDelegatedAuthRuntimeParsingTests + { + [TestMethod] + public void TestRuntimeCanParseUserDelegatedAuthConfig() + { + // Arrange + string configJson = @"{ + ""$schema"": ""test"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""testconnectionstring"", + ""user-delegated-auth"": { + ""enabled"": true, + ""database-audience"": ""https://database.windows.net"" + } + }, + ""runtime"": { + ""rest"": { + ""enabled"": true, + ""path"": ""/api"" + }, + ""graphql"": { + ""enabled"": true, + ""path"": ""/graphql"", + ""allow-introspection"": true + }, + ""host"": { + ""mode"": ""development"", + ""cors"": { + ""origins"": [], + ""allow-credentials"": false + }, + ""authentication"": { + ""provider"": ""StaticWebApps"" + } + } + }, + ""entities"": {} + }"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig? config); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(config); + Assert.IsNotNull(config.DataSource.UserDelegatedAuth); + Assert.IsTrue(config.DataSource.UserDelegatedAuth.Enabled); + Assert.AreEqual("https://database.windows.net", config.DataSource.UserDelegatedAuth.DatabaseAudience); + } + + [TestMethod] + public void TestRuntimeCanParseConfigWithoutUserDelegatedAuth() + { + // Arrange + string configJson = @"{ + ""$schema"": ""test"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""testconnectionstring"" + }, + ""runtime"": { + ""rest"": { + ""enabled"": true, + ""path"": ""/api"" + }, + ""graphql"": { + ""enabled"": true, + ""path"": ""/graphql"", + ""allow-introspection"": true + }, + ""host"": { + ""mode"": ""development"", + ""cors"": { + ""origins"": [], + ""allow-credentials"": false + }, + ""authentication"": { + ""provider"": ""StaticWebApps"" + } + } + }, + ""entities"": {} + }"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig? 
config); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(config); + Assert.IsNull(config.DataSource.UserDelegatedAuth); + } + } +} diff --git a/src/Cli/Commands/AutoConfigOptions.cs b/src/Cli/Commands/AutoConfigOptions.cs new file mode 100644 index 0000000000..41227cd03a --- /dev/null +++ b/src/Cli/Commands/AutoConfigOptions.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.IO.Abstractions; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Product; +using Cli.Constants; +using CommandLine; +using Microsoft.Extensions.Logging; +using static Cli.Utils; +using ILogger = Microsoft.Extensions.Logging.ILogger; + +namespace Cli.Commands +{ + /// + /// AutoConfigOptions command options + /// This command will be used to configure autoentities definitions in the config file. + /// + [Verb("auto-config", isDefault: false, HelpText = "Configure autoentities definitions", Hidden = false)] + public class AutoConfigOptions : Options + { + public AutoConfigOptions( + string definitionName, + IEnumerable? patternsInclude = null, + IEnumerable? patternsExclude = null, + string? patternsName = null, + string? templateMcpDmlTool = null, + bool? templateRestEnabled = null, + bool? templateGraphqlEnabled = null, + bool? templateCacheEnabled = null, + int? templateCacheTtlSeconds = null, + string? templateCacheLevel = null, + bool? templateHealthEnabled = null, + IEnumerable? permissions = null, + string? 
config = null) + : base(config) + { + DefinitionName = definitionName; + PatternsInclude = patternsInclude; + PatternsExclude = patternsExclude; + PatternsName = patternsName; + TemplateMcpDmlTool = templateMcpDmlTool; + TemplateRestEnabled = templateRestEnabled; + TemplateGraphqlEnabled = templateGraphqlEnabled; + TemplateCacheEnabled = templateCacheEnabled; + TemplateCacheTtlSeconds = templateCacheTtlSeconds; + TemplateCacheLevel = templateCacheLevel; + TemplateHealthEnabled = templateHealthEnabled; + Permissions = permissions; + } + + [Value(0, Required = true, HelpText = "Name of the autoentities definition to configure.")] + public string DefinitionName { get; } + + [Option("patterns.include", Required = false, HelpText = "T-SQL LIKE pattern(s) to include database objects. Space-separated array of patterns. Default: '%.%'.")] + public IEnumerable? PatternsInclude { get; } + + [Option("patterns.exclude", Required = false, HelpText = "T-SQL LIKE pattern(s) to exclude database objects. Space-separated array of patterns. Default: null")] + public IEnumerable? PatternsExclude { get; } + + [Option("patterns.name", Required = false, HelpText = "Interpolation syntax for entity naming (must be unique for each generated entity). Default: '{object}'")] + public string? PatternsName { get; } + + [Option("template.mcp.dml-tool", Required = false, HelpText = "Enable/disable DML tools for generated entities. Allowed values: true, false. Default: true")] + public string? TemplateMcpDmlTool { get; } + + [Option("template.rest.enabled", Required = false, HelpText = "Enable/disable REST endpoint for generated entities. Allowed values: true, false. Default: true")] + public bool? TemplateRestEnabled { get; } + + [Option("template.graphql.enabled", Required = false, HelpText = "Enable/disable GraphQL endpoint for generated entities. Allowed values: true, false. Default: true")] + public bool? 
TemplateGraphqlEnabled { get; } + + [Option("template.cache.enabled", Required = false, HelpText = "Enable/disable cache for generated entities. Allowed values: true, false. Default: false")] + public bool? TemplateCacheEnabled { get; } + + [Option("template.cache.ttl-seconds", Required = false, HelpText = "Cache time-to-live in seconds for generated entities. Default: null")] + public int? TemplateCacheTtlSeconds { get; } + + [Option("template.cache.level", Required = false, HelpText = "Cache level for generated entities. Allowed values: L1, L1L2. Default: L1L2")] + public string? TemplateCacheLevel { get; } + + [Option("template.health.enabled", Required = false, HelpText = "Enable/disable health check for generated entities. Allowed values: true, false. Default: true")] + public bool? TemplateHealthEnabled { get; } + + [Option("permissions", Required = false, Separator = ':', HelpText = "Permissions for generated entities in the format role:actions (e.g., anonymous:read). Default: null")] + public IEnumerable? Permissions { get; } + + public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) + { + logger.LogInformation("{productName} {version}", PRODUCT_NAME, ProductInfo.GetProductVersion()); + bool isSuccess = ConfigGenerator.TryConfigureAutoentities(this, loader, fileSystem); + if (isSuccess) + { + logger.LogInformation("Successfully configured autoentities definition: {DefinitionName}.", DefinitionName); + return CliReturnCode.SUCCESS; + } + else + { + logger.LogError("Failed to configure autoentities definition: {DefinitionName}.", DefinitionName); + return CliReturnCode.GENERAL_ERROR; + } + } + } +} diff --git a/src/Cli/Commands/AutoConfigSimulateOptions.cs b/src/Cli/Commands/AutoConfigSimulateOptions.cs new file mode 100644 index 0000000000..205c8d7d9e --- /dev/null +++ b/src/Cli/Commands/AutoConfigSimulateOptions.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.IO.Abstractions; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Product; +using Cli.Constants; +using CommandLine; +using Microsoft.Extensions.Logging; +using static Cli.Utils; +using ILogger = Microsoft.Extensions.Logging.ILogger; + +namespace Cli.Commands +{ + /// + /// Command options for the auto-config-simulate verb. + /// Simulates autoentities generation by querying the database and displaying + /// which entities would be created for each filter definition. + /// + [Verb("auto-config-simulate", isDefault: false, HelpText = "Simulate autoentities generation by querying the database and displaying the results.", Hidden = false)] + public class AutoConfigSimulateOptions : Options + { + public AutoConfigSimulateOptions( + string? output = null, + string? config = null) + : base(config) + { + Output = output; + } + + [Option('o', "output", Required = false, HelpText = "Path to output CSV file. If not specified, results are printed to the console.")] + public string? Output { get; } + + public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) + { + logger.LogInformation("{productName} {version}", PRODUCT_NAME, ProductInfo.GetProductVersion()); + bool isSuccess = ConfigGenerator.TrySimulateAutoentities(this, loader, fileSystem); + if (isSuccess) + { + return CliReturnCode.SUCCESS; + } + else + { + logger.LogError("Failed to simulate autoentities."); + return CliReturnCode.GENERAL_ERROR; + } + } + } +} diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index c3e0352249..262cbc9145 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -28,6 +28,9 @@ public ConfigureOptions( string? dataSourceOptionsContainer = null, string? dataSourceOptionsSchema = null, bool? dataSourceOptionsSetSessionContext = null, + string? dataSourceHealthName = null, + bool? dataSourceUserDelegatedAuthEnabled = null, + string? 
dataSourceUserDelegatedAuthDatabaseAudience = null, int? depthLimit = null, bool? runtimeGraphQLEnabled = null, string? runtimeGraphQLPath = null, @@ -48,6 +51,7 @@ public ConfigureOptions( bool? runtimeMcpDmlToolsExecuteEntityEnabled = null, bool? runtimeCacheEnabled = null, int? runtimeCacheTtl = null, + CompressionLevel? runtimeCompressionLevel = null, HostMode? runtimeHostMode = null, IEnumerable? runtimeHostCorsOrigins = null, bool? runtimeHostCorsAllowCredentials = null, @@ -81,6 +85,9 @@ public ConfigureOptions( DataSourceOptionsContainer = dataSourceOptionsContainer; DataSourceOptionsSchema = dataSourceOptionsSchema; DataSourceOptionsSetSessionContext = dataSourceOptionsSetSessionContext; + DataSourceHealthName = dataSourceHealthName; + DataSourceUserDelegatedAuthEnabled = dataSourceUserDelegatedAuthEnabled; + DataSourceUserDelegatedAuthDatabaseAudience = dataSourceUserDelegatedAuthDatabaseAudience; // GraphQL DepthLimit = depthLimit; RuntimeGraphQLEnabled = runtimeGraphQLEnabled; @@ -105,6 +112,8 @@ public ConfigureOptions( // Cache RuntimeCacheEnabled = runtimeCacheEnabled; RuntimeCacheTTL = runtimeCacheTtl; + // Compression + RuntimeCompressionLevel = runtimeCompressionLevel; // Host RuntimeHostMode = runtimeHostMode; RuntimeHostCorsOrigins = runtimeHostCorsOrigins; @@ -152,6 +161,15 @@ public ConfigureOptions( [Option("data-source.options.set-session-context", Required = false, HelpText = "Enable session context. Allowed values: true (default), false.")] public bool? DataSourceOptionsSetSessionContext { get; } + [Option("data-source.health.name", Required = false, HelpText = "Identifier for data source in health check report.")] + public string? DataSourceHealthName { get; } + + [Option("data-source.user-delegated-auth.enabled", Required = false, HelpText = "Enable user-delegated authentication (OBO) for Azure SQL and SQL Server. Default: false (boolean).")] + public bool? 
DataSourceUserDelegatedAuthEnabled { get; } + + [Option("data-source.user-delegated-auth.database-audience", Required = false, HelpText = "Database resource identifier for token acquisition (e.g., https://database.windows.net for Azure SQL).")] + public string? DataSourceUserDelegatedAuthDatabaseAudience { get; } + [Option("runtime.graphql.depth-limit", Required = false, HelpText = "Max allowed depth of the nested query. Allowed values: (0,2147483647] inclusive. Default is infinity. Use -1 to remove limit.")] public int? DepthLimit { get; } @@ -212,6 +230,9 @@ public ConfigureOptions( [Option("runtime.cache.ttl-seconds", Required = false, HelpText = "Customize the DAB cache's global default time to live in seconds. Default: 5 seconds (Integer).")] public int? RuntimeCacheTTL { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] + public CompressionLevel? RuntimeCompressionLevel { get; } + [Option("runtime.host.mode", Required = false, HelpText = "Set the host running mode of DAB in Development or Production. Default: Development.")] public HostMode? RuntimeHostMode { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 78a5e63a7d..40bb6c7262 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -2,16 +2,20 @@ // Licensed under the MIT License. 
using System.Collections.ObjectModel; +using System.Data; using System.Diagnostics.CodeAnalysis; using System.IO.Abstractions; +using System.Text; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Service; using Cli.Commands; +using Microsoft.Data.SqlClient; using Microsoft.Extensions.Logging; using Serilog; using static Cli.Utils; @@ -643,6 +647,7 @@ private static bool TryUpdateConfiguredDataSourceOptions( DatabaseType dbType = runtimeConfig.DataSource.DatabaseType; string dataSourceConnectionString = runtimeConfig.DataSource.ConnectionString; DatasourceHealthCheckConfig? datasourceHealthCheckConfig = runtimeConfig.DataSource.Health; + UserDelegatedAuthOptions? userDelegatedAuthConfig = runtimeConfig.DataSource.UserDelegatedAuth; if (options.DataSourceDatabaseType is not null) { @@ -684,8 +689,71 @@ private static bool TryUpdateConfiguredDataSourceOptions( dbOptions.Add(namingPolicy.ConvertName(nameof(MsSqlOptions.SetSessionContext)), options.DataSourceOptionsSetSessionContext.Value); } + // Handle health.name option + if (options.DataSourceHealthName is not null) + { + // If there's no existing health config, create one with the name + // Note: Passing enabled: null results in Enabled = true at runtime (default behavior) + // but UserProvidedEnabled = false, so the enabled property won't be serialized to JSON. + // This ensures only the name property is written to the config file. + if (datasourceHealthCheckConfig is null) + { + datasourceHealthCheckConfig = new DatasourceHealthCheckConfig(enabled: null, name: options.DataSourceHealthName); + } + else + { + // Update the existing health config with the new name while preserving other settings. 
+ // DatasourceHealthCheckConfig is a record (immutable), so we create a new instance. + // Preserve threshold only if it was explicitly set by the user + int? thresholdToPreserve = datasourceHealthCheckConfig.UserProvidedThresholdMs + ? datasourceHealthCheckConfig.ThresholdMs + : null; + // Preserve enabled only if it was explicitly set by the user + bool? enabledToPreserve = datasourceHealthCheckConfig.UserProvidedEnabled + ? datasourceHealthCheckConfig.Enabled + : null; + datasourceHealthCheckConfig = new DatasourceHealthCheckConfig( + enabled: enabledToPreserve, + name: options.DataSourceHealthName, + thresholdMs: thresholdToPreserve); + } + } + + // Handle user-delegated-auth options + if (options.DataSourceUserDelegatedAuthEnabled is not null + || options.DataSourceUserDelegatedAuthDatabaseAudience is not null) + { + // Determine the enabled state: use new value if provided, otherwise preserve existing + bool enabled = options.DataSourceUserDelegatedAuthEnabled + ?? userDelegatedAuthConfig?.Enabled + ?? false; + + // Validate that user-delegated-auth is only used with MSSQL when enabled=true + if (enabled && !DatabaseType.MSSQL.Equals(dbType)) + { + _logger.LogError("user-delegated-auth is only supported for database-type 'mssql'."); + return false; + } + + // Get database-audience: use new value if provided, otherwise preserve existing + string? databaseAudience = options.DataSourceUserDelegatedAuthDatabaseAudience + ?? userDelegatedAuthConfig?.DatabaseAudience; + + // Get provider: preserve existing or use default "EntraId" + string? provider = userDelegatedAuthConfig?.Provider ?? "EntraId"; + + // Create or update user-delegated-auth config + userDelegatedAuthConfig = new UserDelegatedAuthOptions( + Enabled: enabled, + Provider: provider, + DatabaseAudience: databaseAudience); + } + dbOptions = EnumerableUtilities.IsNullOrEmpty(dbOptions) ? 
null : dbOptions; - DataSource dataSource = new(dbType, dataSourceConnectionString, dbOptions, datasourceHealthCheckConfig); + DataSource dataSource = new(dbType, dataSourceConnectionString, dbOptions, datasourceHealthCheckConfig) + { + UserDelegatedAuth = userDelegatedAuthConfig + }; runtimeConfig = runtimeConfig with { DataSource = dataSource }; return runtimeConfig != null; @@ -849,6 +917,21 @@ private static bool TryUpdateConfiguredRuntimeOptions( } } + // Compression: Level + if (options.RuntimeCompressionLevel != null) + { + CompressionOptions updatedCompressionOptions = runtimeConfig?.Runtime?.Compression ?? new(); + bool status = TryUpdateConfiguredCompressionValues(options, ref updatedCompressionOptions); + if (status) + { + runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Compression = updatedCompressionOptions } }; + } + else + { + return false; + } + } + // Host: Mode, Cors.Origins, Cors.AllowCredentials, Authentication.Provider, Authentication.Jwt.Audience, Authentication.Jwt.Issuer if (options.RuntimeHostMode != null || options.RuntimeHostCorsOrigins != null || @@ -1226,6 +1309,37 @@ private static bool TryUpdateConfiguredCacheValues( } } + /// + /// Attempts to update the Config parameters in the Compression runtime settings based on the provided value. + /// Validates user-provided parameters and then returns true if the updated Compression options + /// need to be overwritten on the existing config parameters. + /// + /// options. + /// updatedCompressionOptions. + /// True if the value needs to be updated in the runtime config, else false + private static bool TryUpdateConfiguredCompressionValues( + ConfigureOptions options, + ref CompressionOptions updatedCompressionOptions) + { + try + { + // Runtime.Compression.Level + CompressionLevel? 
updatedValue = options?.RuntimeCompressionLevel; + if (updatedValue != null) + { + updatedCompressionOptions = updatedCompressionOptions with { Level = updatedValue.Value, UserProvidedLevel = true }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Compression.Level as '{updatedValue}'", updatedValue); + } + + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to configure RuntimeConfig.Compression with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Attempts to update the Config parameters in the Host runtime settings based on the provided value. /// Validates that any user-provided parameter value is valid and then returns true if the updated Host options @@ -2747,6 +2861,483 @@ public static bool TryAddTelemetry(AddTelemetryOptions options, FileSystemRuntim return WriteRuntimeConfigToFile(runtimeConfigFile, runtimeConfig, fileSystem); } + /// + /// Configures an autoentities definition in the runtime config. + /// This method updates or creates an autoentities definition with the specified patterns, template, and permissions. + /// + /// The autoentities configuration options provided by the user. + /// The config loader to read the existing config. + /// The filesystem used for reading and writing the config file. + /// True if the autoentities definition was successfully configured; otherwise, false. + public static bool TryConfigureAutoentities(AutoConfigOptions options, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) + { + if (!TryGetConfigFileBasedOnCliPrecedence(loader, options.Config, out string runtimeConfigFile)) + { + return false; + } + + if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? 
runtimeConfig)) + { + _logger.LogError("Failed to read the config file: {runtimeConfigFile}.", runtimeConfigFile); + return false; + } + + // Get existing autoentities or create new collection + Dictionary autoEntitiesDictionary = runtimeConfig.Autoentities?.Autoentities != null + ? new Dictionary(runtimeConfig.Autoentities.Autoentities) + : new Dictionary(); + + // Get existing autoentity definition or create a new one + Autoentity? existingAutoentity = null; + if (autoEntitiesDictionary.TryGetValue(options.DefinitionName, out Autoentity? value)) + { + existingAutoentity = value; + } + + // Build patterns + AutoentityPatterns patterns = BuildAutoentityPatterns(options, existingAutoentity); + + // Build template + AutoentityTemplate? template = BuildAutoentityTemplate(options, existingAutoentity); + if (template is null) + { + return false; + } + + // Build permissions + EntityPermission[]? permissions = BuildAutoentityPermissions(options, existingAutoentity); + + // Check if permissions parsing failed (non-empty input but failed to parse) + if (permissions is null && options.Permissions is not null && options.Permissions.Count() > 0) + { + _logger.LogError("Failed to parse permissions."); + return false; + } + + // Create updated autoentity + Autoentity updatedAutoentity = new( + Patterns: patterns, + Template: template, + Permissions: permissions ?? existingAutoentity?.Permissions + ); + + // Update the dictionary + autoEntitiesDictionary[options.DefinitionName] = updatedAutoentity; + + // Update runtime config + runtimeConfig = runtimeConfig with + { + Autoentities = new RuntimeAutoentities(autoEntitiesDictionary) + }; + + return WriteRuntimeConfigToFile(runtimeConfigFile, runtimeConfig, fileSystem); + } + + /// + /// Builds the AutoentityPatterns object from the provided options and existing autoentity. + /// + private static AutoentityPatterns BuildAutoentityPatterns(AutoConfigOptions options, Autoentity? existingAutoentity) + { + string[]? 
include = null; + string[]? exclude = null; + string? name = null; + bool userProvidedInclude = false; + bool userProvidedExclude = false; + bool userProvidedName = false; + + // Start with existing values + if (existingAutoentity is not null) + { + include = existingAutoentity.Patterns.Include; + exclude = existingAutoentity.Patterns.Exclude; + name = existingAutoentity.Patterns.Name; + userProvidedInclude = existingAutoentity.Patterns.UserProvidedIncludeOptions; + userProvidedExclude = existingAutoentity.Patterns.UserProvidedExcludeOptions; + userProvidedName = existingAutoentity.Patterns.UserProvidedNameOptions; + } + + // Override with new values if provided + if (options.PatternsInclude is not null && options.PatternsInclude.Any()) + { + include = options.PatternsInclude.ToArray(); + userProvidedInclude = true; + _logger.LogInformation("Updated patterns.include for definition '{DefinitionName}'", options.DefinitionName); + } + + if (options.PatternsExclude is not null && options.PatternsExclude.Any()) + { + exclude = options.PatternsExclude.ToArray(); + userProvidedExclude = true; + _logger.LogInformation("Updated patterns.exclude for definition '{DefinitionName}'", options.DefinitionName); + } + + if (!string.IsNullOrWhiteSpace(options.PatternsName)) + { + name = options.PatternsName; + userProvidedName = true; + _logger.LogInformation("Updated patterns.name for definition '{DefinitionName}'", options.DefinitionName); + } + + return new AutoentityPatterns(Include: include, Exclude: exclude, Name: name) + { + UserProvidedIncludeOptions = userProvidedInclude, + UserProvidedExcludeOptions = userProvidedExclude, + UserProvidedNameOptions = userProvidedName + }; + } + + /// + /// Builds the AutoentityTemplate object from the provided options and existing autoentity. + /// Returns null if validation fails. + /// + private static AutoentityTemplate? BuildAutoentityTemplate(AutoConfigOptions options, Autoentity? 
/// <summary>
/// Builds the AutoentityTemplate object by overlaying CLI-provided options on top of the
/// existing autoentity's template (or defaults when no definition exists yet).
/// Returns null if validation of any provided value fails.
/// </summary>
/// <param name="options">CLI options that may override individual template settings.</param>
/// <param name="existingAutoentity">The existing autoentity definition, if any.</param>
/// <returns>The merged template, or null when a provided value is invalid.</returns>
private static AutoentityTemplate? BuildAutoentityTemplate(AutoConfigOptions options, Autoentity? existingAutoentity)
{
    // Start with existing values or defaults.
    EntityMcpOptions? mcp = existingAutoentity?.Template.Mcp;
    EntityRestOptions rest = existingAutoentity?.Template.Rest ?? new EntityRestOptions();
    EntityGraphQLOptions graphQL = existingAutoentity?.Template.GraphQL ?? new EntityGraphQLOptions(string.Empty, string.Empty);
    EntityHealthCheckConfig health = existingAutoentity?.Template.Health ?? new EntityHealthCheckConfig();
    EntityCacheOptions cache = existingAutoentity?.Template.Cache ?? new EntityCacheOptions();

    bool userProvidedMcp = existingAutoentity?.Template.UserProvidedMcpOptions ?? false;
    bool userProvidedRest = existingAutoentity?.Template.UserProvidedRestOptions ?? false;
    bool userProvidedGraphQL = existingAutoentity?.Template.UserProvidedGraphQLOptions ?? false;
    bool userProvidedHealth = existingAutoentity?.Template.UserProvidedHealthOptions ?? false;
    bool userProvidedCache = existingAutoentity?.Template.UserProvidedCacheOptions ?? false;

    // Update MCP options.
    if (!string.IsNullOrWhiteSpace(options.TemplateMcpDmlTool))
    {
        if (!bool.TryParse(options.TemplateMcpDmlTool, out bool mcpDmlToolValue))
        {
            _logger.LogError("Invalid value for template.mcp.dml-tool: {value}. Valid values are: true, false", options.TemplateMcpDmlTool);
            return null;
        }

        // TODO: Task #2949. Once autoentities is able to support stored procedures, we will need to change this in order to allow the CLI to edit the custom tool section.
        bool? customToolEnabled = mcp?.UserProvidedCustomToolEnabled == true ? mcp.CustomToolEnabled : null;
        bool? dmlToolValue = mcpDmlToolValue;
        mcp = new EntityMcpOptions(customToolEnabled: customToolEnabled, dmlToolsEnabled: dmlToolValue);
        userProvidedMcp = true;
        _logger.LogInformation("Updated template.mcp.dml-tool for definition '{DefinitionName}'", options.DefinitionName);
    }

    // Update REST options.
    if (options.TemplateRestEnabled is not null)
    {
        rest = rest with { Enabled = options.TemplateRestEnabled.Value };
        userProvidedRest = true;
        _logger.LogInformation("Updated template.rest.enabled for definition '{DefinitionName}'", options.DefinitionName);
    }

    // Update GraphQL options.
    if (options.TemplateGraphqlEnabled is not null)
    {
        graphQL = graphQL with { Enabled = options.TemplateGraphqlEnabled.Value };
        userProvidedGraphQL = true;
        _logger.LogInformation("Updated template.graphql.enabled for definition '{DefinitionName}'", options.DefinitionName);
    }

    // Update Health options. Only carry over first/threshold when the user had set them.
    if (options.TemplateHealthEnabled is not null)
    {
        health = new EntityHealthCheckConfig(
            enabled: options.TemplateHealthEnabled.Value,
            first: health.UserProvidedFirst ? health.First : null,
            thresholdMs: health.UserProvidedThresholdMs ? health.ThresholdMs : null
        );
        userProvidedHealth = true;
        _logger.LogInformation("Updated template.health.enabled for definition '{DefinitionName}'", options.DefinitionName);
    }

    // Update Cache options. Existing ttl/level values are carried over only if the user set them.
    bool cacheUpdated = false;
    bool? cacheEnabled = cache.Enabled;
    int? cacheTtl = cache.UserProvidedTtlOptions ? cache.TtlSeconds : null;
    EntityCacheLevel? cacheLevel = cache.UserProvidedLevelOptions ? cache.Level : null;

    if (options.TemplateCacheEnabled is not null)
    {
        cacheEnabled = options.TemplateCacheEnabled.Value;
        cacheUpdated = true;
        _logger.LogInformation("Updated template.cache.enabled for definition '{DefinitionName}'", options.DefinitionName);
    }

    if (options.TemplateCacheTtlSeconds is not null)
    {
        cacheTtl = options.TemplateCacheTtlSeconds.Value;
        cacheUpdated = true;
        if (RuntimeConfigValidatorUtil.IsTTLValid(ttl: (int)cacheTtl))
        {
            _logger.LogInformation("Updated template.cache.ttl-seconds for definition '{DefinitionName}'", options.DefinitionName);
        }
        else
        {
            // BUGFIX: the failure message previously referenced "Runtime.Cache.ttl-seconds",
            // but this code path updates the autoentity template's cache TTL
            // (see the matching success message above).
            _logger.LogError("Failed to update template.cache.ttl-seconds as '{updatedValue}' value in TTL is not valid.", cacheTtl);
            return null;
        }
    }

    if (!string.IsNullOrWhiteSpace(options.TemplateCacheLevel))
    {
        if (!Enum.TryParse(options.TemplateCacheLevel, ignoreCase: true, out EntityCacheLevel cacheLevelValue))
        {
            _logger.LogError(EnumExtensions.GenerateMessageForInvalidInput<EntityCacheLevel>(options.TemplateCacheLevel));
            return null;
        }

        cacheLevel = cacheLevelValue;
        cacheUpdated = true;
        _logger.LogInformation("Updated template.cache.level for definition '{DefinitionName}'", options.DefinitionName);
    }

    if (cacheUpdated)
    {
        cache = new EntityCacheOptions(Enabled: cacheEnabled, TtlSeconds: cacheTtl, Level: cacheLevel);
        userProvidedCache = true;
    }

    return new AutoentityTemplate(
        Rest: rest,
        GraphQL: graphQL,
        Mcp: mcp,
        Health: health,
        Cache: cache
    )
    {
        UserProvidedMcpOptions = userProvidedMcp,
        UserProvidedRestOptions = userProvidedRest,
        UserProvidedGraphQLOptions = userProvidedGraphQL,
        UserProvidedHealthOptions = userProvidedHealth,
        UserProvidedCacheOptions = userProvidedCache
    };
}
/// <summary>
/// Builds the permissions array from the provided options and existing autoentity.
/// Returns the existing permissions unchanged when no new permissions were supplied;
/// returns null when the supplied permissions fail to parse.
/// </summary>
/// <param name="options">CLI options that may carry new permission strings.</param>
/// <param name="existingAutoentity">The existing autoentity definition, if any.</param>
private static EntityPermission[]? BuildAutoentityPermissions(AutoConfigOptions options, Autoentity? existingAutoentity)
{
    if (options.Permissions is null || !options.Permissions.Any())
    {
        return existingAutoentity?.Permissions;
    }

    // Parse the permissions supplied on the command line.
    EntityPermission[]? parsedPermissions = ParsePermission(options.Permissions, null, null, null);
    if (parsedPermissions is not null)
    {
        _logger.LogInformation("Updated permissions for definition '{DefinitionName}'", options.DefinitionName);
    }

    return parsedPermissions;
}

// Column names returned by the autoentities SQL query.
private const string AUTOENTITIES_COLUMN_ENTITY_NAME = "entity_name";
private const string AUTOENTITIES_COLUMN_OBJECT = "object";
private const string AUTOENTITIES_COLUMN_SCHEMA = "schema";

/// <summary>
/// Simulates the autoentities generation by querying the database and displaying
/// which entities would be created for each autoentities filter definition.
/// When an output file path is provided, results are written as CSV; otherwise they are printed to the console.
/// </summary>
/// <param name="options">The simulate options provided by the user.</param>
/// <param name="loader">The config loader to read the existing config.</param>
/// <param name="fileSystem">The filesystem used for reading the config file and writing output.</param>
/// <returns>True if the simulation completed successfully; otherwise, false.</returns>
public static bool TrySimulateAutoentities(AutoConfigSimulateOptions options, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem)
{
    if (!TryGetConfigFileBasedOnCliPrecedence(loader, options.Config, out string runtimeConfigFile))
    {
        return false;
    }

    // Load config with env var replacement so the connection string is fully resolved.
    DeserializationVariableReplacementSettings replacementSettings = new(doReplaceEnvVar: true);
    if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? runtimeConfig, replacementSettings: replacementSettings))
    {
        _logger.LogError("Failed to read the config file: {runtimeConfigFile}.", runtimeConfigFile);
        return false;
    }

    if (runtimeConfig.DataSource.DatabaseType != DatabaseType.MSSQL)
    {
        _logger.LogError("Autoentities simulation is only supported for MSSQL databases. Current database type: {DatabaseType}.", runtimeConfig.DataSource.DatabaseType);
        return false;
    }

    if (runtimeConfig.Autoentities?.Autoentities is null || runtimeConfig.Autoentities.Autoentities.Count == 0)
    {
        _logger.LogError("No autoentities definitions found in the config file.");
        return false;
    }

    string connectionString = runtimeConfig.DataSource.ConnectionString;
    if (string.IsNullOrWhiteSpace(connectionString))
    {
        _logger.LogError("Connection string is missing or empty in config file.");
        return false;
    }

    MsSqlQueryBuilder queryBuilder = new();
    string query = queryBuilder.BuildGetAutoentitiesQuery();

    Dictionary<string, List<(string EntityName, string SchemaName, string ObjectName)>> results = new();

    try
    {
        using SqlConnection connection = new(connectionString);
        connection.Open();

        foreach ((string filterName, Autoentity autoentity) in runtimeConfig.Autoentities.Autoentities)
        {
            // BUGFIX: Patterns.Include/Exclude/Name are nullable (a definition may omit them).
            // string.Join throws ArgumentNullException on a null array, and assigning null to
            // SqlParameter.Value means "parameter not supplied", which fails the query.
            // NOTE(review): assumes the SQL treats an empty pattern string as "no filter" —
            // confirm against BuildGetAutoentitiesQuery.
            string include = string.Join(",", autoentity.Patterns.Include ?? Array.Empty<string>());
            string exclude = string.Join(",", autoentity.Patterns.Exclude ?? Array.Empty<string>());
            string namePattern = autoentity.Patterns.Name ?? string.Empty;

            List<(string EntityName, string SchemaName, string ObjectName)> filterResults = new();

            using SqlCommand command = new(query, connection);
            SqlParameter includeParameter = new("@include_pattern", SqlDbType.NVarChar)
            {
                Value = include
            };
            SqlParameter excludeParameter = new("@exclude_pattern", SqlDbType.NVarChar)
            {
                Value = exclude
            };
            SqlParameter namePatternParameter = new("@name_pattern", SqlDbType.NVarChar)
            {
                Value = namePattern
            };

            command.Parameters.Add(includeParameter);
            command.Parameters.Add(excludeParameter);
            command.Parameters.Add(namePatternParameter);

            using SqlDataReader reader = command.ExecuteReader();
            while (reader.Read())
            {
                string entityName = reader[AUTOENTITIES_COLUMN_ENTITY_NAME]?.ToString() ?? string.Empty;
                string objectName = reader[AUTOENTITIES_COLUMN_OBJECT]?.ToString() ?? string.Empty;
                string schemaName = reader[AUTOENTITIES_COLUMN_SCHEMA]?.ToString() ?? string.Empty;

                if (!string.IsNullOrWhiteSpace(entityName) && !string.IsNullOrWhiteSpace(objectName))
                {
                    filterResults.Add((entityName, schemaName, objectName));
                }
            }

            results[filterName] = filterResults;
        }
    }
    catch (Exception ex)
    {
        _logger.LogError("Failed to query the database: {Message}", ex.Message);
        return false;
    }

    if (!string.IsNullOrWhiteSpace(options.Output))
    {
        return WriteSimulationResultsToCsvFile(options.Output, results, fileSystem);
    }
    else
    {
        WriteSimulationResultsToConsole(results);
        return true;
    }
}

/// <summary>
/// Writes the autoentities simulation results to the console in a human-readable format.
/// Results are grouped by filter name with entity-to-database-object mappings.
/// </summary>
/// <param name="results">The simulation results keyed by filter (definition) name.</param>
private static void WriteSimulationResultsToConsole(Dictionary<string, List<(string EntityName, string SchemaName, string ObjectName)>> results)
{
    Console.WriteLine("AutoEntities Simulation Results");
    Console.WriteLine();

    foreach ((string filterName, List<(string EntityName, string SchemaName, string ObjectName)> matches) in results)
    {
        Console.WriteLine($"Filter: {filterName}");
        Console.WriteLine($"Matches: {matches.Count}");
        Console.WriteLine();

        if (matches.Count == 0)
        {
            Console.WriteLine("(no matches)");
        }
        else
        {
            // Pad entity names so the "->" column lines up across rows.
            int maxEntityNameLength = matches.Max(m => m.EntityName.Length);
            foreach ((string entityName, string schemaName, string objectName) in matches)
            {
                Console.WriteLine($"{entityName.PadRight(maxEntityNameLength)} -> {schemaName}.{objectName}");
            }
        }

        Console.WriteLine();
    }
}
/// <summary>
/// Writes the autoentities simulation results to a CSV file with a header row
/// followed by one row per matched entity. An existing file at the path is overwritten.
/// </summary>
/// <param name="outputPath">The file path to write the CSV output to.</param>
/// <param name="results">The simulation results keyed by filter (definition) name.</param>
/// <param name="fileSystem">The filesystem abstraction used for writing the file.</param>
/// <returns>True if the file was written successfully; otherwise, false.</returns>
private static bool WriteSimulationResultsToCsvFile(
    string outputPath,
    Dictionary<string, List<(string EntityName, string SchemaName, string ObjectName)>> results,
    IFileSystem fileSystem)
{
    try
    {
        StringBuilder csv = new();
        csv.AppendLine("filter_name,entity_name,database_object");

        foreach ((string filterName, List<(string EntityName, string SchemaName, string ObjectName)> matches) in results)
        {
            foreach ((string entityName, string schemaName, string objectName) in matches)
            {
                // Each field is escaped individually; the database object is schema-qualified.
                string row = string.Join(
                    ",",
                    QuoteCsvValue(filterName),
                    QuoteCsvValue(entityName),
                    QuoteCsvValue($"{schemaName}.{objectName}"));
                csv.AppendLine(row);
            }
        }

        fileSystem.File.WriteAllText(outputPath, csv.ToString());
        _logger.LogInformation("Simulation results written to {outputPath}.", outputPath);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Failed to write output file: {Message}", ex.Message);
        return false;
    }
}

/// <summary>
/// Quotes a value for inclusion in a CSV field.
/// A value containing a comma, double-quote, or newline is wrapped in double-quotes,
/// with embedded double-quotes escaped by doubling them; other values pass through unchanged.
/// </summary>
/// <param name="value">The value to quote.</param>
/// <returns>A properly escaped CSV field value.</returns>
private static string QuoteCsvValue(string value)
{
    bool needsQuoting = value.IndexOfAny(new[] { ',', '"', '\n', '\r' }) >= 0;
    if (!needsQuoting)
    {
        return value;
    }

    string escaped = value.Replace("\"", "\"\"");
    return $"\"{escaped}\"";
}
diff --git a/src/Cli/Program.cs b/src/Cli/Program.cs index 036f3dc2a3..de16ed27f5 100644 --- a/src/Cli/Program.cs +++ b/src/Cli/Program.cs @@ -58,7 +58,7 @@ public static int Execute(string[] args, ILogger cliLogger, IFileSystem fileSyst }); // Parsing user arguments and executing required methods. - int result = parser.ParseArguments(args) + int result = parser.ParseArguments(args) .MapResult( (InitOptions options) => options.Handler(cliLogger, loader, fileSystem), (AddOptions options) => options.Handler(cliLogger, loader, fileSystem), @@ -67,6 +67,8 @@ public static int Execute(string[] args, ILogger cliLogger, IFileSystem fileSyst (ValidateOptions options) => options.Handler(cliLogger, loader, fileSystem), (AddTelemetryOptions options) => options.Handler(cliLogger, loader, fileSystem), (ConfigureOptions options) => options.Handler(cliLogger, loader, fileSystem), + (AutoConfigOptions options) => options.Handler(cliLogger, loader, fileSystem), + (AutoConfigSimulateOptions options) => options.Handler(cliLogger, loader, fileSystem), (ExportOptions options) => options.Handler(cliLogger, loader, fileSystem), errors => DabCliParserErrorHandler.ProcessErrorsAndReturnExitCode(errors)); diff --git a/src/Config/Converters/AutoentityConverter.cs b/src/Config/Converters/AutoentityConverter.cs index 5c09ed8e7b..b97b3e7be3 100644 --- a/src/Config/Converters/AutoentityConverter.cs +++ b/src/Config/Converters/AutoentityConverter.cs @@ -90,6 +90,7 @@ public override void Write(Utf8JsonWriter writer, Autoentity value, JsonSerializ AutoentityTemplate? 
template = value?.Template; if (template?.UserProvidedRestOptions is true || template?.UserProvidedGraphQLOptions is true + || template?.UserProvidedMcpOptions is true || template?.UserProvidedHealthOptions is true || template?.UserProvidedCacheOptions is true) { @@ -99,7 +100,7 @@ public override void Write(Utf8JsonWriter writer, Autoentity value, JsonSerializ autoentityTemplateConverter.Write(writer, template, options); } - if (value?.Permissions is not null) + if (value?.Permissions is not null && value.Permissions.Length > 0) { writer.WritePropertyName("permissions"); JsonSerializer.Serialize(writer, value.Permissions, options); diff --git a/src/Config/Converters/AutoentityTemplateConverter.cs b/src/Config/Converters/AutoentityTemplateConverter.cs index 275cfc4314..8f5cb1276c 100644 --- a/src/Config/Converters/AutoentityTemplateConverter.cs +++ b/src/Config/Converters/AutoentityTemplateConverter.cs @@ -113,7 +113,17 @@ public override void Write(Utf8JsonWriter writer, AutoentityTemplate value, Json if (value?.UserProvidedGraphQLOptions is true) { writer.WritePropertyName("graphql"); - JsonSerializer.Serialize(writer, value.GraphQL, options); + // For autoentities template, only write the enabled property + // The type (singular/plural) is determined by the generated entities + writer.WriteStartObject(); + writer.WriteBoolean("enabled", value.GraphQL.Enabled); + writer.WriteEndObject(); + } + + if (value?.UserProvidedMcpOptions is true) + { + writer.WritePropertyName("mcp"); + JsonSerializer.Serialize(writer, value.Mcp, options); } if (value?.UserProvidedHealthOptions is true) diff --git a/src/Config/Converters/CompressionOptionsConverterFactory.cs b/src/Config/Converters/CompressionOptionsConverterFactory.cs new file mode 100644 index 0000000000..092b2ff74a --- /dev/null +++ b/src/Config/Converters/CompressionOptionsConverterFactory.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
using System.Text.Json;
using System.Text.Json.Serialization;
using Azure.DataApiBuilder.Config.ObjectModel;

namespace Azure.DataApiBuilder.Config.Converters;

/// <summary>
/// Defines how DAB reads and writes the compression options (JSON).
/// </summary>
internal class CompressionOptionsConverterFactory : JsonConverterFactory
{
    /// <inheritdoc/>
    public override bool CanConvert(Type typeToConvert)
    {
        return typeToConvert.IsAssignableTo(typeof(CompressionOptions));
    }

    /// <inheritdoc/>
    public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options)
    {
        return new CompressionOptionsConverter();
    }

    private class CompressionOptionsConverter : JsonConverter<CompressionOptions>
    {
        /// <summary>
        /// Defines how DAB reads the compression options and defines which values are
        /// used to instantiate CompressionOptions. Unknown properties are skipped.
        /// </summary>
        /// <exception cref="JsonException">Thrown for malformed JSON or an invalid level value.</exception>
        public override CompressionOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            if (reader.TokenType == JsonTokenType.Null)
            {
                return null;
            }

            if (reader.TokenType != JsonTokenType.StartObject)
            {
                throw new JsonException("Expected start of object.");
            }

            CompressionLevel level = CompressionOptions.DEFAULT_LEVEL;
            bool userProvidedLevel = false;

            while (reader.Read())
            {
                if (reader.TokenType == JsonTokenType.EndObject)
                {
                    break;
                }

                if (reader.TokenType == JsonTokenType.PropertyName)
                {
                    string? propertyName = reader.GetString();
                    reader.Read();

                    if (string.Equals(propertyName, "level", StringComparison.OrdinalIgnoreCase))
                    {
                        // BUGFIX: GetString() on a non-string token (e.g. a number or object)
                        // throws InvalidOperationException; converters should report malformed
                        // input as JsonException. An explicit null is treated as "not provided",
                        // matching the previous behavior.
                        if (reader.TokenType == JsonTokenType.Null)
                        {
                            continue;
                        }

                        if (reader.TokenType != JsonTokenType.String)
                        {
                            throw new JsonException("Invalid compression level: expected a string. Valid values are: optimal, fastest, none.");
                        }

                        string? levelStr = reader.GetString();
                        if (levelStr is not null)
                        {
                            if (Enum.TryParse(levelStr, ignoreCase: true, out CompressionLevel parsedLevel))
                            {
                                level = parsedLevel;
                                userProvidedLevel = true;
                            }
                            else
                            {
                                throw new JsonException($"Invalid compression level: '{levelStr}'. Valid values are: optimal, fastest, none.");
                            }
                        }
                    }
                    else
                    {
                        // Skip unknown properties and their values (including objects/arrays)
                        reader.Skip();
                    }
                }
            }

            return new CompressionOptions(level) with { UserProvidedLevel = userProvidedLevel };
        }

        /// <summary>
        /// When writing the CompressionOptions back to a JSON file, only write the level
        /// property and value when it was provided by the user.
        /// </summary>
        public override void Write(Utf8JsonWriter writer, CompressionOptions value, JsonSerializerOptions options)
        {
            writer.WriteStartObject();

            if (value is not null && value.UserProvidedLevel)
            {
                writer.WritePropertyName("level");
                writer.WriteStringValue(value.Level.ToString().ToLowerInvariant());
            }

            writer.WriteEndObject();
        }
    }
}
replaceme datasourceOptions = optionsDict; } + break; + case "user-delegated-auth": + if (reader.TokenType != JsonTokenType.Null) + { + try + { + userDelegatedAuth = JsonSerializer.Deserialize(ref reader, options); + } + catch (Exception e) + { + throw new JsonException($"Error while deserializing DataSource user-delegated-auth: {e.Message}"); + } + } + break; default: throw new JsonException($"Unexpected property {propertyName} while deserializing DataSource."); diff --git a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs index d8286ff7a0..cbe4511daa 100644 --- a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs +++ b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs @@ -114,11 +114,21 @@ public HealthCheckOptionsConverter(DeserializationVariableReplacementSettings? r public override void Write(Utf8JsonWriter writer, DatasourceHealthCheckConfig value, JsonSerializerOptions options) { - if (value?.UserProvidedEnabled is true) + // Write the health object if any of these conditions are met: + // - enabled was explicitly provided by the user + // - name property has a value + // - threshold was explicitly provided by the user + if (value?.UserProvidedEnabled is true || value?.Name is not null || value?.UserProvidedThresholdMs is true) { writer.WriteStartObject(); - writer.WritePropertyName("enabled"); - JsonSerializer.Serialize(writer, value.Enabled, options); + + // Only write enabled if it was explicitly provided by the user + if (value?.UserProvidedEnabled is true) + { + writer.WritePropertyName("enabled"); + JsonSerializer.Serialize(writer, value.Enabled, options); + } + if (value?.Name is not null) { writer.WritePropertyName("name"); diff --git a/src/Config/Converters/RuntimeAutoentitiesConverter.cs b/src/Config/Converters/RuntimeAutoentitiesConverter.cs index b65bcb9989..597ef18523 100644 --- a/src/Config/Converters/RuntimeAutoentitiesConverter.cs 
+++ b/src/Config/Converters/RuntimeAutoentitiesConverter.cs @@ -29,7 +29,7 @@ class RuntimeAutoentitiesConverter : JsonConverter public override void Write(Utf8JsonWriter writer, RuntimeAutoentities value, JsonSerializerOptions options) { writer.WriteStartObject(); - foreach ((string key, Autoentity autoEntity) in value.AutoEntities) + foreach ((string key, Autoentity autoEntity) in value.Autoentities) { writer.WritePropertyName(key); JsonSerializer.Serialize(writer, autoEntity, options); diff --git a/src/Config/DataApiBuilderException.cs b/src/Config/DataApiBuilderException.cs index b7696c4deb..95fe916c75 100644 --- a/src/Config/DataApiBuilderException.cs +++ b/src/Config/DataApiBuilderException.cs @@ -20,6 +20,11 @@ public class DataApiBuilderException : Exception public const string GRAPHQL_MUTATION_FIELD_AUTHZ_FAILURE = "Unauthorized due to one or more fields in this mutation."; public const string GRAPHQL_GROUPBY_FIELD_AUTHZ_FAILURE = "Access forbidden to field '{0}' referenced in the groupBy argument."; public const string GRAPHQL_AGGREGATION_FIELD_AUTHZ_FAILURE = "Access forbidden to field '{0}' referenced in the aggregation function '{1}'."; + public const string OBO_IDENTITY_CLAIMS_MISSING = "User-delegated authentication failed: Neither 'oid' nor 'sub' claim found in the access token."; + public const string OBO_TENANT_CLAIM_MISSING = "User-delegated authentication failed: 'tid' (tenant id) claim not found in the access token."; + public const string OBO_TOKEN_ACQUISITION_FAILED = "User-delegated authentication failed: Unable to acquire database access token on behalf of the user."; + public const string OBO_MISSING_USER_CONTEXT = "User-delegated authentication failed: Missing or invalid 'Authorization: Bearer ' header. 
OBO requires a valid user token to exchange for database access."; + public const string OBO_MISSING_DATABASE_AUDIENCE = "User-delegated authentication failed: 'database-audience' is not configured in the data source's user-delegated-auth settings."; public enum SubStatusCodes { @@ -127,7 +132,11 @@ public enum SubStatusCodes /// /// Error due to client input validation failure. /// - DatabaseInputError + DatabaseInputError, + /// + /// User-delegated (OBO) authentication failed due to missing identity claims. + /// + OboAuthenticationFailure } public HttpStatusCode StatusCode { get; } diff --git a/src/Config/DatabasePrimitives/DatabaseObject.cs b/src/Config/DatabasePrimitives/DatabaseObject.cs index 8636e8c005..be1eff45ba 100644 --- a/src/Config/DatabasePrimitives/DatabaseObject.cs +++ b/src/Config/DatabasePrimitives/DatabaseObject.cs @@ -43,13 +43,15 @@ public override bool Equals(object? other) public bool Equals(DatabaseObject? other) { return other is not null && - SchemaName.Equals(other.SchemaName) && - Name.Equals(other.Name); + string.Equals(SchemaName, other.SchemaName, StringComparison.OrdinalIgnoreCase) && + string.Equals(Name, other.Name, StringComparison.OrdinalIgnoreCase); } public override int GetHashCode() { - return HashCode.Combine(SchemaName, Name); + return HashCode.Combine( + SchemaName is null ? 0 : StringComparer.OrdinalIgnoreCase.GetHashCode(SchemaName), + Name is null ? 0 : StringComparer.OrdinalIgnoreCase.GetHashCode(Name)); } /// diff --git a/src/Config/ObjectModel/CompressionLevel.cs b/src/Config/ObjectModel/CompressionLevel.cs new file mode 100644 index 0000000000..f60a81c45f --- /dev/null +++ b/src/Config/ObjectModel/CompressionLevel.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Specifies the compression level for HTTP response compression. 
/// <summary>
/// Specifies the compression level for HTTP response compression.
/// Serialized as a case-insensitive string (e.g. "optimal").
/// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
public enum CompressionLevel
{
    /// <summary>Provides the best compression ratio at the cost of speed.</summary>
    Optimal,

    /// <summary>Provides the fastest compression at the cost of compression ratio.</summary>
    Fastest,

    /// <summary>Disables compression.</summary>
    None
}

/// <summary>
/// Configuration options for HTTP response compression.
/// Tracks whether the level was explicitly supplied so the serializer can
/// omit the property when it was defaulted.
/// </summary>
public record CompressionOptions
{
    /// <summary>
    /// Default compression level is Optimal.
    /// </summary>
    public const CompressionLevel DEFAULT_LEVEL = CompressionLevel.Optimal;

    /// <summary>
    /// The compression level to use for HTTP response compression.
    /// </summary>
    [JsonPropertyName("level")]
    public CompressionLevel Level { get; init; } = DEFAULT_LEVEL;

    /// <summary>
    /// Flag which informs CLI and JSON serializer whether to write Level
    /// property and value to the runtime config file.
    /// </summary>
    [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
    public bool UserProvidedLevel { get; init; } = false;

    /// <summary>
    /// Constructor used when a level is explicitly specified; marks the level as user-provided.
    /// </summary>
    [JsonConstructor]
    public CompressionOptions(CompressionLevel Level = DEFAULT_LEVEL)
    {
        this.Level = Level;
        this.UserProvidedLevel = true;
    }

    /// <summary>
    /// Default parameterless constructor for cases where no compression level is specified;
    /// the level is defaulted and NOT marked as user-provided.
    /// </summary>
    public CompressionOptions()
    {
        this.Level = DEFAULT_LEVEL;
        this.UserProvidedLevel = false;
    }
}
/// <summary>
/// Options for user-delegated authentication (On-Behalf-Of, "OBO") for a data source.
///
/// Disabled (default): DAB connects to the database as a single application principal
/// (Managed Identity or connection-string credentials), regardless of the calling user.
///
/// Enabled: DAB exchanges the calling user's JWT for a database access token via the
/// On-Behalf-Of flow, so the database connection runs as the actual user — enabling
/// Row-Level Security (RLS) filtering based on user identity.
///
/// OBO requires a dedicated Azure AD App Registration (separate from DAB's Managed
/// Identity) acting as a confidential client. The operator must expose its credentials
/// through the environment variables below (read at startup via
/// Environment.GetEnvironmentVariable), e.g. via Azure Container Apps secrets,
/// Kubernetes secrets, Docker environment variables, or the local shell:
/// - DAB_OBO_CLIENT_ID: Application (client) ID of the OBO App Registration
/// - DAB_OBO_TENANT_ID: Directory (tenant) ID of the OBO App Registration
/// - DAB_OBO_CLIENT_SECRET: Client secret of the OBO App Registration (not a user secret)
///
/// DAB-specific prefixes (DAB_OBO_*) are used instead of AZURE_* because
/// DefaultAzureCredential interprets AZURE_CLIENT_ID as a User-Assigned Managed
/// Identity ID. At startup (no user context yet), DAB falls back to Managed Identity
/// for metadata operations.
/// </summary>
/// <param name="Enabled">Whether user-delegated authentication is enabled.</param>
/// <param name="Provider">The authentication provider (currently only EntraId is supported).</param>
/// <param name="DatabaseAudience">Audience used when acquiring database tokens on behalf of the user.</param>
public record UserDelegatedAuthOptions(
    [property: JsonPropertyName("enabled")] bool Enabled = false,
    [property: JsonPropertyName("provider")] string? Provider = null,
    [property: JsonPropertyName("database-audience")] string? DatabaseAudience = null)
{
    /// <summary>
    /// Default duration, in minutes, to cache tokens for a given delegated identity.
    /// With a 5-minute early refresh buffer, tokens are refreshed at the 40-minute mark.
    /// </summary>
    public const int DEFAULT_TOKEN_CACHE_DURATION_MINUTES = 45;

    /// <summary>
    /// Environment variable name for the OBO App Registration client ID.
    /// DAB-specific prefix avoids conflict with AZURE_CLIENT_ID, which
    /// DefaultAzureCredential/ManagedIdentityCredential treat as a
    /// User-Assigned Managed Identity ID.
    /// </summary>
    public const string DAB_OBO_CLIENT_ID_ENV_VAR = "DAB_OBO_CLIENT_ID";

    /// <summary>
    /// Environment variable name for the OBO App Registration client secret,
    /// used for the On-Behalf-Of token exchange.
    /// </summary>
    public const string DAB_OBO_CLIENT_SECRET_ENV_VAR = "DAB_OBO_CLIENT_SECRET";

    /// <summary>
    /// Environment variable name for the OBO tenant ID.
    /// DAB-specific prefix for consistency with the OBO client ID.
    /// </summary>
    public const string DAB_OBO_TENANT_ID_ENV_VAR = "DAB_OBO_TENANT_ID";
}
- public RuntimeAutoentities(IReadOnlyDictionary autoEntities) + /// The collection of auto-entities to map to RuntimeAutoentities. + public RuntimeAutoentities(IReadOnlyDictionary autoentities) { - AutoEntities = autoEntities; + Autoentities = autoentities; + } + + public IEnumerator> GetEnumerator() + { + return Autoentities.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); } } diff --git a/src/Config/ObjectModel/RuntimeConfig.cs b/src/Config/ObjectModel/RuntimeConfig.cs index 1e567da1cd..b5c241a0a4 100644 --- a/src/Config/ObjectModel/RuntimeConfig.cs +++ b/src/Config/ObjectModel/RuntimeConfig.cs @@ -25,7 +25,7 @@ public record RuntimeConfig [JsonPropertyName("azure-key-vault")] public AzureKeyVaultOptions? AzureKeyVault { get; init; } - public RuntimeAutoentities? Autoentities { get; init; } + public RuntimeAutoentities Autoentities { get; init; } public virtual RuntimeEntities Entities { get; init; } @@ -216,6 +216,8 @@ Runtime.GraphQL.FeatureFlags is not null && private Dictionary _entityNameToDataSourceName = new(); + private Dictionary _autoentityNameToDataSourceName = new(); + private Dictionary _entityPathNameToEntityName = new(); /// @@ -245,6 +247,21 @@ public bool TryGetEntityNameFromPath(string entityPathName, [NotNullWhen(true)] return _entityPathNameToEntityName.TryGetValue(entityPathName, out entityName); } + public bool TryAddEntityNameToDataSourceName(string entityName) + { + return _entityNameToDataSourceName.TryAdd(entityName, this.DefaultDataSourceName); + } + + public bool TryAddGeneratedAutoentityNameToDataSourceName(string entityName, string autoEntityDefinition) + { + if (_autoentityNameToDataSourceName.TryGetValue(autoEntityDefinition, out string? dataSourceName)) + { + return _entityNameToDataSourceName.TryAdd(entityName, dataSourceName); + } + + return false; + } + /// /// Constructor for runtimeConfig. /// To be used when setting up from cli json scenario. 
@@ -268,8 +285,8 @@ public RuntimeConfig( this.DataSource = DataSource; this.Runtime = Runtime; this.AzureKeyVault = AzureKeyVault; - this.Entities = Entities; - this.Autoentities = Autoentities; + this.Entities = Entities ?? new RuntimeEntities(new Dictionary()); + this.Autoentities = Autoentities ?? new RuntimeAutoentities(new Dictionary()); this.DefaultDataSourceName = Guid.NewGuid().ToString(); if (this.DataSource is null) @@ -287,17 +304,29 @@ public RuntimeConfig( }; _entityNameToDataSourceName = new Dictionary(); - if (Entities is null) + if (Entities is null && this.Entities.Entities.Count == 0 && + Autoentities is null && this.Autoentities.Autoentities.Count == 0) { throw new DataApiBuilderException( - message: "entities is a mandatory property in DAB Config", + message: "Configuration file should contain either at least the entities or autoentities property", statusCode: HttpStatusCode.UnprocessableEntity, subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError); } - foreach (KeyValuePair entity in Entities) + if (Entities is not null) { - _entityNameToDataSourceName.TryAdd(entity.Key, this.DefaultDataSourceName); + foreach (KeyValuePair entity in Entities) + { + _entityNameToDataSourceName.TryAdd(entity.Key, this.DefaultDataSourceName); + } + } + + if (Autoentities is not null) + { + foreach (KeyValuePair autoentity in Autoentities) + { + _autoentityNameToDataSourceName.TryAdd(autoentity.Key, this.DefaultDataSourceName); + } } // Process data source and entities information for each database in multiple database scenario. @@ -305,7 +334,8 @@ public RuntimeConfig( if (DataSourceFiles is not null && DataSourceFiles.SourceFiles is not null) { - IEnumerable> allEntities = Entities.AsEnumerable(); + IEnumerable>? allEntities = Entities?.AsEnumerable(); + IEnumerable>? allAutoentities = Autoentities?.AsEnumerable(); // Iterate through all the datasource files and load the config. 
IFileSystem fileSystem = new FileSystem(); // This loader is not used as a part of hot reload and therefore does not need a handler. @@ -322,7 +352,9 @@ public RuntimeConfig( { _dataSourceNameToDataSource = _dataSourceNameToDataSource.Concat(config._dataSourceNameToDataSource).ToDictionary(kvp => kvp.Key, kvp => kvp.Value); _entityNameToDataSourceName = _entityNameToDataSourceName.Concat(config._entityNameToDataSourceName).ToDictionary(kvp => kvp.Key, kvp => kvp.Value); - allEntities = allEntities.Concat(config.Entities.AsEnumerable()); + _autoentityNameToDataSourceName = _autoentityNameToDataSourceName.Concat(config._autoentityNameToDataSourceName).ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + allEntities = allEntities?.Concat(config.Entities.AsEnumerable()); + allAutoentities = allAutoentities?.Concat(config.Autoentities.AsEnumerable()); } catch (Exception e) { @@ -336,11 +368,11 @@ public RuntimeConfig( } } - this.Entities = new RuntimeEntities(allEntities.ToDictionary(x => x.Key, x => x.Value)); + this.Entities = new RuntimeEntities(allEntities != null ? allEntities.ToDictionary(x => x.Key, x => x.Value) : new Dictionary()); + this.Autoentities = new RuntimeAutoentities(allAutoentities != null ? allAutoentities.ToDictionary(x => x.Key, x => x.Value) : new Dictionary()); } SetupDataSourcesUsed(); - } /// @@ -351,17 +383,19 @@ public RuntimeConfig( /// Default datasource. /// Runtime settings. /// Entities + /// Autoentities /// List of datasource files for multiple db scenario.Null for single db scenario. /// DefaultDataSourceName to maintain backward compatibility. /// Dictionary mapping datasourceName to datasource object. /// Dictionary mapping entityName to datasourceName. /// Datasource files which represent list of child runtimeconfigs for multi-db scenario. 
- public RuntimeConfig(string Schema, DataSource DataSource, RuntimeOptions Runtime, RuntimeEntities Entities, string DefaultDataSourceName, Dictionary DataSourceNameToDataSource, Dictionary EntityNameToDataSourceName, DataSourceFiles? DataSourceFiles = null, AzureKeyVaultOptions? AzureKeyVault = null) + public RuntimeConfig(string Schema, DataSource DataSource, RuntimeOptions Runtime, RuntimeEntities Entities, string DefaultDataSourceName, Dictionary DataSourceNameToDataSource, Dictionary EntityNameToDataSourceName, DataSourceFiles? DataSourceFiles = null, AzureKeyVaultOptions? AzureKeyVault = null, RuntimeAutoentities? Autoentities = null) { this.Schema = Schema; this.DataSource = DataSource; this.Runtime = Runtime; this.Entities = Entities; + this.Autoentities = Autoentities ?? new RuntimeAutoentities(new Dictionary()); this.DefaultDataSourceName = DefaultDataSourceName; _dataSourceNameToDataSource = DataSourceNameToDataSource; _entityNameToDataSourceName = EntityNameToDataSourceName; @@ -451,6 +485,24 @@ public DataSource GetDataSourceFromEntityName(string entityName) return _dataSourceNameToDataSource[_entityNameToDataSourceName[entityName]]; } + /// + /// Gets datasourceName from AutoentityNameToDatasourceName dictionary. + /// + /// autoentityName + /// DataSourceName + public string GetDataSourceNameFromAutoentityName(string autoentityName) + { + if (!_autoentityNameToDataSourceName.TryGetValue(autoentityName, out string? autoentityDataSource)) + { + throw new DataApiBuilderException( + message: $"{autoentityName} is not a valid autoentity.", + statusCode: HttpStatusCode.NotFound, + subStatusCode: DataApiBuilderException.SubStatusCodes.EntityNotFound); + } + + return autoentityDataSource; + } + /// /// Validates if datasource is present in runtimeConfig. /// @@ -476,12 +528,13 @@ Runtime is not null && Runtime.Host is not null /// /// Returns the ttl-seconds value for a given entity. 
- /// If the property is not set, returns the global default value set in the runtime config. - /// If the global default value is not set, the default value is used (5 seconds). + /// If the entity explicitly sets ttl-seconds, that value is used. + /// Otherwise, falls back to the global cache TTL setting. + /// Callers are responsible for checking whether caching is enabled before using the result. /// /// Name of the entity to check cache configuration. /// Number of seconds (ttl) that a cache entry should be valid before cache eviction. - /// Raised when an invalid entity name is provided or if the entity has caching disabled. + /// Raised when an invalid entity name is provided. public virtual int GetEntityCacheEntryTtl(string entityName) { if (!Entities.TryGetValue(entityName, out Entity? entityConfig)) @@ -492,31 +545,23 @@ public virtual int GetEntityCacheEntryTtl(string entityName) subStatusCode: DataApiBuilderException.SubStatusCodes.EntityNotFound); } - if (!entityConfig.IsCachingEnabled) - { - throw new DataApiBuilderException( - message: $"{entityName} does not have caching enabled.", - statusCode: HttpStatusCode.BadRequest, - subStatusCode: DataApiBuilderException.SubStatusCodes.NotSupported); - } - - if (entityConfig.Cache.UserProvidedTtlOptions) + if (entityConfig.Cache is not null && entityConfig.Cache.UserProvidedTtlOptions) { return entityConfig.Cache.TtlSeconds.Value; } - else - { - return GlobalCacheEntryTtl(); - } + + return GlobalCacheEntryTtl(); } /// /// Returns the cache level value for a given entity. - /// If the property is not set, returns the default (L1L2) for a given entity. + /// If the entity explicitly sets level, that value is used. + /// Otherwise, falls back to the global cache level or the default. + /// Callers are responsible for checking whether caching is enabled before using the result. /// /// Name of the entity to check cache configuration. /// Cache level that a cache entry should be stored in. 
- /// Raised when an invalid entity name is provided or if the entity has caching disabled. + /// Raised when an invalid entity name is provided. public virtual EntityCacheLevel GetEntityCacheEntryLevel(string entityName) { if (!Entities.TryGetValue(entityName, out Entity? entityConfig)) @@ -527,22 +572,12 @@ public virtual EntityCacheLevel GetEntityCacheEntryLevel(string entityName) subStatusCode: DataApiBuilderException.SubStatusCodes.EntityNotFound); } - if (!entityConfig.IsCachingEnabled) - { - throw new DataApiBuilderException( - message: $"{entityName} does not have caching enabled.", - statusCode: HttpStatusCode.BadRequest, - subStatusCode: DataApiBuilderException.SubStatusCodes.NotSupported); - } - - if (entityConfig.Cache.UserProvidedLevelOptions) + if (entityConfig.Cache is not null && entityConfig.Cache.UserProvidedLevelOptions) { return entityConfig.Cache.Level.Value; } - else - { - return EntityCacheLevel.L1L2; - } + + return EntityCacheOptions.DEFAULT_LEVEL; } /// diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 6f6c046651..525ea8d089 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -17,6 +17,7 @@ public record RuntimeOptions public RuntimeCacheOptions? Cache { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } + public CompressionOptions? Compression { get; init; } [JsonConstructor] public RuntimeOptions( @@ -28,7 +29,8 @@ public RuntimeOptions( TelemetryOptions? Telemetry = null, RuntimeCacheOptions? Cache = null, PaginationOptions? Pagination = null, - RuntimeHealthCheckConfig? Health = null) + RuntimeHealthCheckConfig? Health = null, + CompressionOptions? 
Compression = null) { this.Rest = Rest; this.GraphQL = GraphQL; @@ -39,6 +41,7 @@ public RuntimeOptions( this.Cache = Cache; this.Pagination = Pagination; this.Health = Health; + this.Compression = Compression; } /// diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index 9a54d09d8e..ae5c2dde95 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -320,6 +320,7 @@ public static JsonSerializerOptions GetSerializationOptions( options.Converters.Add(new EntityMcpOptionsConverterFactory()); options.Converters.Add(new RuntimeCacheOptionsConverterFactory()); options.Converters.Add(new RuntimeCacheLevel2OptionsConverterFactory()); + options.Converters.Add(new CompressionOptionsConverterFactory()); options.Converters.Add(new MultipleCreateOptionsConverter()); options.Converters.Add(new MultipleMutationOptionsConverter(options)); options.Converters.Add(new DataSourceConverterFactory(replacementSettings)); @@ -498,4 +499,9 @@ public void InsertWantedChangesInProductionMode() RuntimeConfig = runtimeConfigCopy; } } + + public void EditRuntimeConfig(RuntimeConfig newRuntimeConfig) + { + RuntimeConfig = newRuntimeConfig; + } } diff --git a/src/Core/Configurations/RuntimeConfigProvider.cs b/src/Core/Configurations/RuntimeConfigProvider.cs index b46a716f48..644782e2cc 100644 --- a/src/Core/Configurations/RuntimeConfigProvider.cs +++ b/src/Core/Configurations/RuntimeConfigProvider.cs @@ -411,4 +411,19 @@ private static RuntimeConfig HandleCosmosNoSqlConfiguration(string? schema, Runt return runtimeConfig; } + + public void AddMergedEntitiesToConfig(Dictionary newEntities) + { + Dictionary entities = new(_configLoader.RuntimeConfig!.Entities); + foreach ((string name, Entity entity) in newEntities) + { + entities.Add(name, entity); + } + + RuntimeConfig newRuntimeConfig = _configLoader.RuntimeConfig! 
with + { + Entities = new(entities) + }; + _configLoader.EditRuntimeConfig(newRuntimeConfig); + } } diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index ec97a48e4c..f3c9a4261f 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -49,6 +49,19 @@ public class RuntimeConfigValidator : IConfigValidator DatabaseType.DWSQL ]; + // Error messages for user-delegated authentication configuration. + public const string USER_DELEGATED_AUTH_DATABASE_TYPE_ERR_MSG = + "User-delegated authentication is only supported when data-source.database-type is 'mssql'."; + + public const string USER_DELEGATED_AUTH_MISSING_AUDIENCE_ERR_MSG = + "data-source.user-delegated-auth.database-audience must be set when user-delegated-auth is configured."; + + public const string USER_DELEGATED_AUTH_CACHING_ERR_MSG = + "runtime.cache.enabled must be false when user-delegated-auth is configured."; + + public const string USER_DELEGATED_AUTH_MISSING_CREDENTIALS_ERR_MSG = + "User-delegated authentication requires DAB_OBO_CLIENT_ID, DAB_OBO_TENANT_ID, and DAB_OBO_CLIENT_SECRET environment variables."; + // Error messages. public const string INVALID_CLAIMS_IN_POLICY_ERR_MSG = "One or more claim types supplied in the database policy are not supported."; @@ -83,18 +96,6 @@ public void ValidateConfigProperties() ValidateLoggerFilters(runtimeConfig); ValidateAzureLogAnalyticsAuth(runtimeConfig); ValidateFileSinkPath(runtimeConfig); - - // Running these graphQL validations only in development mode to ensure - // fast startup of engine in production mode. 
- if (runtimeConfig.IsDevelopmentMode()) - { - ValidateEntityConfiguration(runtimeConfig); - - if (runtimeConfig.IsGraphQLEnabled) - { - ValidateEntitiesDoNotGenerateDuplicateQueriesOrMutation(runtimeConfig.DataSource.DatabaseType, runtimeConfig.Entities); - } - } } /// @@ -119,6 +120,68 @@ public void ValidateDataSourceInConfig( } ValidateDatabaseType(runtimeConfig, fileSystem, logger); + + ValidateUserDelegatedAuthOptions(runtimeConfig); + } + + /// + /// Validates configuration for user-delegated authentication (OBO). + /// When any data source has user-delegated-auth configured, the following + /// rules are enforced: + /// - data-source.database-type must be "mssql". + /// - data-source.user-delegated-auth.database-audience must be present. + /// - runtime.cache.enabled must be false. + /// - Environment variables DAB_OBO_CLIENT_ID, DAB_OBO_TENANT_ID, and DAB_OBO_CLIENT_SECRET must be set. + /// + /// Runtime configuration. + private void ValidateUserDelegatedAuthOptions(RuntimeConfig runtimeConfig) + { + foreach (DataSource dataSource in runtimeConfig.ListAllDataSources()) + { + // Skip validation if user-delegated-auth is not configured or not enabled + if (dataSource.UserDelegatedAuth is null || !dataSource.UserDelegatedAuth.Enabled) + { + continue; + } + + if (dataSource.DatabaseType != DatabaseType.MSSQL) + { + HandleOrRecordException(new DataApiBuilderException( + message: USER_DELEGATED_AUTH_DATABASE_TYPE_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (string.IsNullOrWhiteSpace(dataSource.UserDelegatedAuth.DatabaseAudience)) + { + HandleOrRecordException(new DataApiBuilderException( + message: USER_DELEGATED_AUTH_MISSING_AUDIENCE_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // Validate OBO App Registration credentials are configured via environment 
variables. + string? clientId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR); + string? tenantId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR); + string? clientSecret = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR); + + if (string.IsNullOrWhiteSpace(clientId) || string.IsNullOrWhiteSpace(tenantId) || string.IsNullOrWhiteSpace(clientSecret)) + { + HandleOrRecordException(new DataApiBuilderException( + message: USER_DELEGATED_AUTH_MISSING_CREDENTIALS_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // Validate caching is disabled when user-delegated-auth is enabled + if (runtimeConfig.Runtime?.Cache?.Enabled == true) + { + HandleOrRecordException(new DataApiBuilderException( + message: USER_DELEGATED_AUTH_CACHING_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } } /// @@ -259,6 +322,11 @@ public async Task TryValidateConfig( _logger.LogInformation("Validating entity relationships."); ValidateRelationshipConfigCorrectness(runtimeConfig); + + // This function initializes the metadata providers which in turn validates the connectivity to the + // database and also validates all the REST and GraphQL paths as well as the permissions of the entities + // that are created from the 'Entities' and 'Autoentities' configuration, including the relationships defined in the config against the database metadata. + // Any exceptions caught during this process are added to the ConfigValidationExceptions list and logged at the end of this function. 
await ValidateEntitiesMetadata(runtimeConfig, loggerFactory); if (validationResult.IsValid && !ConfigValidationExceptions.Any()) @@ -411,6 +479,8 @@ public void ValidateRelationshipConfigCorrectness(RuntimeConfig runtimeConfig) /// This method validates the entities relationships against the database objects using /// metadata from the backend DB generated by this function. /// + /// NOTE: This function should not be used in the regular flow of DAB as we already initialize the metadata providers during startup, + /// doing it again will cause the application to fail as it will try to add data that is already present. public async Task ValidateEntitiesMetadata(RuntimeConfig runtimeConfig, ILoggerFactory loggerFactory) { // Only used for validation so we don't need the handler which is for hot reload scenarios. @@ -424,6 +494,7 @@ public async Task ValidateEntitiesMetadata(RuntimeConfig runtimeConfig, ILoggerF // Only used for validation so we don't need the handler which is for hot reload scenarios. MetadataProviderFactory metadataProviderFactory = new( runtimeConfigProvider: _runtimeConfigProvider, + runtimeConfigValidator: this, queryManagerFactory: queryManagerFactory, logger: loggerFactory.CreateLogger(), fileSystem: _fileSystem, @@ -656,6 +727,7 @@ private void ValidateRestMethods(Entity entity, string entityName) /// /// Helper method to validate that the rest path property for the entity is correctly configured. /// The rest path should not be null/empty and should not contain any reserved characters. + /// Allows sub-directories (forward slashes) in the path. /// /// Name of the entity. /// The rest path for the entity. @@ -672,10 +744,10 @@ private static void ValidateRestPathSettingsForEntity(string entityName, string ); } - if (RuntimeConfigValidatorUtil.DoesUriComponentContainReservedChars(pathForEntity)) + if (!RuntimeConfigValidatorUtil.TryValidateEntityRestPath(pathForEntity, out string? 
errorMessage)) { throw new DataApiBuilderException( - message: $"The rest path: {pathForEntity} for entity: {entityName} contains one or more reserved characters.", + message: $"The rest path: {pathForEntity} for entity: {entityName} {errorMessage ?? "contains invalid characters."}", statusCode: HttpStatusCode.ServiceUnavailable, subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError ); @@ -1518,4 +1590,26 @@ private static bool IsLoggerFilterValid(string loggerFilter) return false; } + + /// + /// Checks that all of the entities created with the Entities and Autoentities properties + /// are valid by having unique paths for both REST and GraphQL, that there are no duplicate + /// Queries or Mutation entities, and ensure the semantic correctness of all the entities. + /// + /// The runtime configuration. + public void ValidateEntityAndAutoentityConfigurations(RuntimeConfig runtimeConfig) + { + if (runtimeConfig.IsDevelopmentMode()) + { + ValidateEntityConfiguration(runtimeConfig); + + if (runtimeConfig.IsGraphQLEnabled) + { + ValidateEntitiesDoNotGenerateDuplicateQueriesOrMutation(runtimeConfig.DataSource.DatabaseType, runtimeConfig.Entities); + } + + // Running only in developer mode to ensure fast and smooth startup in production. + ValidatePermissionsInConfig(runtimeConfig); + } + } } diff --git a/src/Core/Configurations/RuntimeConfigValidatorUtil.cs b/src/Core/Configurations/RuntimeConfigValidatorUtil.cs index be742e586b..fce7821840 100644 --- a/src/Core/Configurations/RuntimeConfigValidatorUtil.cs +++ b/src/Core/Configurations/RuntimeConfigValidatorUtil.cs @@ -66,6 +66,94 @@ public static bool DoesUriComponentContainReservedChars(string uriComponent) return _reservedUriCharsRgx.IsMatch(uriComponent); } + /// + /// Method to validate an entity REST path allowing sub-directories (forward slashes). + /// Each segment of the path is validated for reserved characters and path traversal patterns. + /// + /// The entity REST path to validate. 
+ /// Output parameter containing a specific error message if validation fails. + /// true if the path is valid, false otherwise. + public static bool TryValidateEntityRestPath(string entityRestPath, out string? errorMessage) + { + errorMessage = null; + + // Check for maximum path length (reasonable limit for URL paths) + const int MAX_PATH_LENGTH = 2048; + if (entityRestPath.Length > MAX_PATH_LENGTH) + { + errorMessage = $"exceeds maximum allowed length of {MAX_PATH_LENGTH} characters."; + return false; + } + + // Check for backslash usage - common mistake + if (entityRestPath.Contains('\\')) + { + errorMessage = "contains a backslash (\\). Use forward slash (/) for path separators."; + return false; + } + + // Check for percent-encoded characters (URL encoding not allowed in config) + if (entityRestPath.Contains('%')) + { + errorMessage = "contains percent-encoding (%) which is not allowed. Use literal characters only."; + return false; + } + + // Check for whitespace + if (entityRestPath.Any(char.IsWhiteSpace)) + { + errorMessage = "contains whitespace which is not allowed in URL paths."; + return false; + } + + // Split the path by '/' to validate each segment separately + string[] segments = entityRestPath.Split('/'); + + // Validate each segment doesn't contain reserved characters + foreach (string segment in segments) + { + if (string.IsNullOrEmpty(segment)) + { + errorMessage = "contains empty path segments. Ensure there are no leading, consecutive, or trailing slashes."; + return false; + } + + // Check for path traversal patterns + if (segment == "." || segment == "..") + { + errorMessage = "contains path traversal patterns ('.' or '..') which are not allowed."; + return false; + } + + // Check for specific reserved characters and provide helpful messages + if (segment.Contains('?')) + { + errorMessage = "contains '?' 
which is reserved for query strings in URLs."; + return false; + } + + if (segment.Contains('#')) + { + errorMessage = "contains '#' which is reserved for URL fragments."; + return false; + } + + if (segment.Contains(':')) + { + errorMessage = "contains ':' which is a reserved character and not allowed in URL paths."; + return false; + } + + if (_reservedUriCharsRgx.IsMatch(segment)) + { + errorMessage = "contains reserved characters that are not allowed in URL paths."; + return false; + } + } + + return true; + } + /// /// Method to validate if the TTL passed by the user is valid /// diff --git a/src/Core/Models/GraphQLFilterParsers.cs b/src/Core/Models/GraphQLFilterParsers.cs index 153def832f..90deb884b3 100644 --- a/src/Core/Models/GraphQLFilterParsers.cs +++ b/src/Core/Models/GraphQLFilterParsers.cs @@ -227,6 +227,7 @@ public Predicate Parse( cosmosQueryStructure.DatabaseObject.Name = sourceName; cosmosQueryStructure.SourceAlias = sourceAlias; + cosmosQueryStructure.EntityName = entityName; } } } diff --git a/src/Core/Models/RestRequestContexts/RestRequestContext.cs b/src/Core/Models/RestRequestContexts/RestRequestContext.cs index 70d6a371b5..e9987730a0 100644 --- a/src/Core/Models/RestRequestContexts/RestRequestContext.cs +++ b/src/Core/Models/RestRequestContexts/RestRequestContext.cs @@ -77,6 +77,12 @@ protected RestRequestContext(string entityName, DatabaseObject dbo) /// public NameValueCollection ParsedQueryString { get; set; } = new(); + /// + /// Raw query string from the HTTP request (URL-encoded). + /// Used to preserve encoding for special characters in query parameters. + /// + public string RawQueryString { get; set; } = string.Empty; + /// /// String holds information needed for pagination. /// Based on request this property may or may not be populated. 
diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 6402ce4ecb..081018e820 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -113,14 +113,32 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv context.FieldsToBeReturned = context.ParsedQueryString[key]!.Split(",").ToList(); break; case FILTER_URL: - // save the AST that represents the filter for the query - // ?$filter= - string filterQueryString = $"?{FILTER_URL}={context.ParsedQueryString[key]}"; - context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}"); + // Use raw (URL-encoded) filter value to preserve special characters like & + string? rawFilterValue = ExtractRawQueryParameter(context.RawQueryString, FILTER_URL); + // If key exists in ParsedQueryString but not in RawQueryString, something is wrong + if (rawFilterValue is null) + { + throw new DataApiBuilderException( + message: $"Unable to extract {FILTER_URL} parameter from query string.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause($"?{FILTER_URL}={rawFilterValue}", $"{context.EntityName}.{context.DatabaseObject.FullName}"); break; case SORT_URL: - string sortQueryString = $"?{SORT_URL}={context.ParsedQueryString[key]}"; - (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = GenerateOrderByLists(context, sqlMetadataProvider, sortQueryString); + // Use raw (URL-encoded) orderby value to preserve special characters + string? 
rawSortValue = ExtractRawQueryParameter(context.RawQueryString, SORT_URL); + // If key exists in ParsedQueryString but not in RawQueryString, something is wrong + if (rawSortValue is null) + { + throw new DataApiBuilderException( + message: $"Unable to extract {SORT_URL} parameter from query string.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = GenerateOrderByLists(context, sqlMetadataProvider, $"?{SORT_URL}={rawSortValue}"); break; case AFTER_URL: context.After = context.ParsedQueryString[key]; @@ -283,5 +301,38 @@ private static bool IsNull(string value) { return string.IsNullOrWhiteSpace(value) || string.Equals(value, "null", StringComparison.OrdinalIgnoreCase); } + + /// + /// Extracts the raw (URL-encoded) value of a query parameter from a query string. + /// Preserves special characters like & in filter values (e.g., %26 stays as %26). + /// + /// IMPORTANT: This method assumes the input queryString is a raw, URL-encoded query string + /// where special characters in parameter values are encoded (e.g., & is %26, space is %20). + /// It splits on unencoded '&' characters which are parameter separators in the URL standard. + /// If the queryString has already been decoded, this method will not work correctly. + /// + /// Raw URL-encoded query string (e.g., "?$filter=title%20eq%20%27A%26B%27") + /// The parameter name to extract (e.g., "$filter") + /// The raw encoded value of the parameter, or null if not found + internal static string? ExtractRawQueryParameter(string queryString, string parameterName) + { + if (string.IsNullOrWhiteSpace(queryString)) + { + return null; + } + + // Split on '&' which are parameter separators in properly URL-encoded query strings. + // Any '&' characters within parameter values will be encoded as %26. 
+ foreach (string param in queryString.TrimStart('?').Split('&')) + { + int idx = param.IndexOf('='); + if (idx >= 0 && param.Substring(0, idx).Equals(parameterName, StringComparison.OrdinalIgnoreCase)) + { + return idx < param.Length - 1 ? param.Substring(idx + 1) : string.Empty; + } + } + + return null; + } } } diff --git a/src/Core/Resolvers/Factories/QueryManagerFactory.cs b/src/Core/Resolvers/Factories/QueryManagerFactory.cs index 72c99124c0..68896318d5 100644 --- a/src/Core/Resolvers/Factories/QueryManagerFactory.cs +++ b/src/Core/Resolvers/Factories/QueryManagerFactory.cs @@ -26,6 +26,7 @@ public class QueryManagerFactory : IAbstractQueryManagerFactory private readonly ILogger _logger; private readonly IHttpContextAccessor _contextAccessor; private readonly HotReloadEventHandler? _handler; + private readonly IOboTokenProvider? _oboTokenProvider; /// /// Initiates an instance of QueryManagerFactory @@ -33,17 +34,20 @@ public class QueryManagerFactory : IAbstractQueryManagerFactory /// runtimeconfigprovider. /// logger. /// httpcontextaccessor. + /// Optional OBO token provider for user-delegated authentication. public QueryManagerFactory( RuntimeConfigProvider runtimeConfigProvider, ILogger logger, IHttpContextAccessor contextAccessor, - HotReloadEventHandler? handler) + HotReloadEventHandler? handler, + IOboTokenProvider? 
oboTokenProvider = null) { handler?.Subscribe(QUERY_MANAGER_FACTORY_ON_CONFIG_CHANGED, OnConfigChanged); _handler = handler; _runtimeConfigProvider = runtimeConfigProvider; _logger = logger; _contextAccessor = contextAccessor; + _oboTokenProvider = oboTokenProvider; _queryBuilders = new Dictionary(); _queryExecutors = new Dictionary(); _dbExceptionsParsers = new Dictionary(); @@ -73,7 +77,7 @@ private void ConfigureQueryManagerFactory() case DatabaseType.MSSQL: queryBuilder = new MsSqlQueryBuilder(); exceptionParser = new MsSqlDbExceptionParser(_runtimeConfigProvider); - queryExecutor = new MsSqlQueryExecutor(_runtimeConfigProvider, exceptionParser, _logger, _contextAccessor, _handler); + queryExecutor = new MsSqlQueryExecutor(_runtimeConfigProvider, exceptionParser, _logger, _contextAccessor, _handler, _oboTokenProvider); break; case DatabaseType.MySQL: queryBuilder = new MySqlQueryBuilder(); @@ -88,7 +92,7 @@ private void ConfigureQueryManagerFactory() case DatabaseType.DWSQL: queryBuilder = new DwSqlQueryBuilder(enableNto1JoinOpt: _runtimeConfigProvider.GetConfig().EnableDwNto1JoinOpt); exceptionParser = new MsSqlDbExceptionParser(_runtimeConfigProvider); - queryExecutor = new MsSqlQueryExecutor(_runtimeConfigProvider, exceptionParser, _logger, _contextAccessor, _handler); + queryExecutor = new MsSqlQueryExecutor(_runtimeConfigProvider, exceptionParser, _logger, _contextAccessor, _handler, _oboTokenProvider); break; default: throw new NotSupportedException(dataSource.DatabaseTypeNotSupportedMessage); diff --git a/src/Core/Resolvers/IMsalClientWrapper.cs b/src/Core/Resolvers/IMsalClientWrapper.cs new file mode 100644 index 0000000000..78c59555bd --- /dev/null +++ b/src/Core/Resolvers/IMsalClientWrapper.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Identity.Client; + +namespace Azure.DataApiBuilder.Core.Resolvers; + +/// +/// Wrapper interface for MSAL confidential client operations. 
+/// This abstraction enables unit testing by allowing mocking of MSAL's sealed classes. +/// +public interface IMsalClientWrapper +{ + /// + /// Acquires a token on behalf of a user using the OBO flow. + /// + /// The scopes to request. + /// The user assertion (incoming JWT). + /// Cancellation token. + /// The authentication result containing the access token. + Task AcquireTokenOnBehalfOfAsync( + string[] scopes, + string userAssertion, + CancellationToken cancellationToken = default); +} diff --git a/src/Core/Resolvers/IOboTokenProvider.cs b/src/Core/Resolvers/IOboTokenProvider.cs new file mode 100644 index 0000000000..e921901f32 --- /dev/null +++ b/src/Core/Resolvers/IOboTokenProvider.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Security.Claims; + +namespace Azure.DataApiBuilder.Core.Resolvers; + +/// +/// Provides database access tokens acquired using On-Behalf-Of (OBO) flow +/// for user-delegated authentication scenarios. +/// +public interface IOboTokenProvider +{ + /// + /// Acquires a database access token on behalf of the authenticated user. + /// Uses in-memory caching with early refresh to minimize latency and + /// avoid expired tokens during active requests. + /// + /// The authenticated user's claims principal from JWT validation. + /// The incoming JWT token to use as the OBO assertion. + /// The target database audience (e.g., https://database.windows.net/). + /// Cancellation token. + /// The database access token string, or null if the principal or JWT assertion is null/empty. + /// + /// Thrown when required identity claims (oid/sub or tid) are missing from the principal, + /// or when MSAL token acquisition fails. 
+ /// + Task GetAccessTokenOnBehalfOfAsync( + ClaimsPrincipal principal, + string incomingJwtAssertion, + string databaseAudience, + CancellationToken cancellationToken = default); +} diff --git a/src/Core/Resolvers/IQueryBuilder.cs b/src/Core/Resolvers/IQueryBuilder.cs index 3a14e472cd..f6a6222ddc 100644 --- a/src/Core/Resolvers/IQueryBuilder.cs +++ b/src/Core/Resolvers/IQueryBuilder.cs @@ -90,5 +90,7 @@ public interface IQueryBuilder /// DB Connection Param. /// public string QuoteTableNameAsDBConnectionParam(string param); + + public string BuildGetAutoentitiesQuery() => throw new NotSupportedException($"{GetType().Name} does not support Autoentities yet."); } } diff --git a/src/Core/Resolvers/MsSqlQueryBuilder.cs b/src/Core/Resolvers/MsSqlQueryBuilder.cs index 798c20975a..7adedc64d4 100644 --- a/src/Core/Resolvers/MsSqlQueryBuilder.cs +++ b/src/Core/Resolvers/MsSqlQueryBuilder.cs @@ -506,7 +506,7 @@ public string BuildStoredProcedureResultDetailsQuery(string databaseObjectName) /// 2. are computed based on other columns, /// are considered as read only columns. The query combines both the types of read-only columns and returns the list. /// - /// Param name of the schema/database. + /// Param name of the schema. /// Param name of the table. /// public string BuildQueryToGetReadOnlyColumns(string schemaParamName, string tableParamName) @@ -560,5 +560,188 @@ protected override string BuildPredicates(SqlQueryStructure structure) // contains LIKE and add the ESCAPE clause accordingly. return AddEscapeToLikeClauses(predicates); } + + /// + /// Builds the query used to get the list of tables with the SQL LIKE + /// syntax that will be transformed into entities. + /// NOTE: Currently this query only returns Tables, support for Views will come later. + /// + /// Pattern for tables that will be included. + /// Pattern for tables that will be excluded. + /// Pattern for naming the entities. 
+ public string BuildGetAutoentitiesQuery() + { + string query = @$" + DECLARE @exclude_invalid_types BIT = 1; + + SET NOCOUNT ON; + + WITH + {IncludeAndExcludeSplitQuery(true)}, + {IncludeAndExcludeSplitQuery(false)}, + all_tables AS + ( + SELECT + s.name AS schema_name, + t.name AS object_name, + s.name + N'.' + t.name AS full_name, + N'table' AS object_type, + t.object_id + FROM sys.tables AS t + JOIN sys.schemas AS s + ON t.schema_id = s.schema_id + WHERE EXISTS + ( + SELECT 1 + FROM sys.key_constraints AS kc + WHERE kc.parent_object_id = t.object_id + AND kc.type = 'PK' + ) + ), + eligible_tables AS + ( + SELECT + o.schema_name, + o.object_name, + o.full_name, + o.object_type, + o.object_id, + CASE + WHEN so.is_ms_shipped = 1 THEN 1 + WHEN o.schema_name IN (N'sys', N'INFORMATION_SCHEMA') THEN 1 + WHEN o.object_name IN + ( + N'__EFMigrationsHistory', + N'__MigrationHistory', + N'__FlywayHistory', + N'sysdiagrams' + ) THEN 1 + WHEN o.object_name LIKE N'service_broker_%' THEN 1 + WHEN o.object_name LIKE N'queue_messages_%' THEN 1 + WHEN o.object_name LIKE N'MSmerge_%' THEN 1 + WHEN o.object_name LIKE N'MSreplication_%' THEN 1 + WHEN o.object_name LIKE N'FileTableUpdates$%' THEN 1 + WHEN o.object_name LIKE N'graph_%' THEN 1 + WHEN EXISTS + ( + SELECT 1 + FROM sys.tables AS t + WHERE t.object_id = o.object_id + AND + ( + t.is_tracked_by_cdc = 1 + OR t.temporal_type > 0 + OR t.is_filetable = 1 + OR t.is_memory_optimized = 1 + ) + ) THEN 1 + ELSE 0 + END AS is_system_object + FROM all_tables AS o + JOIN sys.objects AS so + ON so.object_id = o.object_id + ) + SELECT + a.schema_name AS [schema], + a.object_name AS [object], + CASE + WHEN LTRIM(RTRIM(ISNULL(@name_pattern, N''))) = N'' THEN a.object_name + ELSE REPLACE( + REPLACE(@name_pattern, N'{{schema}}', a.schema_name), + N'{{object}}', a.object_name + ) + END AS entity_name, + CASE + WHEN EXISTS + ( + SELECT 1 + FROM sys.columns AS c + JOIN sys.types AS ty + ON c.user_type_id = ty.user_type_id + WHERE c.object_id = 
a.object_id + AND ty.name IN + ( + N'geography', + N'geometry', + N'hierarchyid', + N'sql_variant', + N'xml', + N'rowversion', + N'vector' + ) + ) THEN 1 + ELSE 0 + END AS contains_invalid_types + FROM eligible_tables AS a + WHERE + a.is_system_object = 0 + AND + ( + NOT EXISTS (SELECT 1 FROM exclude_patterns) + OR NOT EXISTS + ( + SELECT 1 + FROM exclude_patterns AS ep + WHERE a.full_name LIKE ep.pattern COLLATE DATABASE_DEFAULT ESCAPE '\' + ) + ) + AND + ( + NOT EXISTS (SELECT 1 FROM include_patterns) + OR EXISTS + ( + SELECT 1 + FROM include_patterns AS ip + WHERE a.full_name LIKE ip.pattern COLLATE DATABASE_DEFAULT ESCAPE '\' + ) + ) + AND + ( + @exclude_invalid_types = 0 + OR NOT EXISTS + ( + SELECT 1 + FROM sys.columns AS c + JOIN sys.types AS ty + ON c.user_type_id = ty.user_type_id + WHERE c.object_id = a.object_id + AND ty.name IN + ( + N'geography', + N'geometry', + N'hierarchyid', + N'sql_variant', + N'xml', + N'rowversion', + N'vector' + ) + ) + ) + ORDER BY + a.schema_name, + a.object_name;"; + + return query; + } + + /// + /// Generates a SQL query segment for splitting include or exclude patterns. + /// + /// Indicates whether to generate the include or exclude pattern query. + /// An SQL query segment as a string. + public static string IncludeAndExcludeSplitQuery(bool isInclude) + { + string pattern = isInclude ? 
"include" : "exclude"; + + string query = $@" + {pattern}_patterns AS + ( + SELECT LTRIM(RTRIM(value)) AS pattern + FROM STRING_SPLIT(ISNULL(@{pattern}_pattern, N''), N',') + WHERE LTRIM(RTRIM(value)) <> N'' + )"; + + return query; + } } } diff --git a/src/Core/Resolvers/MsSqlQueryExecutor.cs b/src/Core/Resolvers/MsSqlQueryExecutor.cs index 5cbe9f6a76..368e5d6b00 100644 --- a/src/Core/Resolvers/MsSqlQueryExecutor.cs +++ b/src/Core/Resolvers/MsSqlQueryExecutor.cs @@ -4,6 +4,7 @@ using System.Data; using System.Data.Common; using System.Net; +using System.Security.Claims; using System.Text; using Azure.Core; using Azure.DataApiBuilder.Config; @@ -11,6 +12,7 @@ using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Product; using Azure.DataApiBuilder.Service.Exceptions; using Azure.Identity; using Microsoft.AspNetCore.Http; @@ -62,6 +64,24 @@ public override IDictionary ConnectionStringB /// private Dictionary _dataSourceToSessionContextUsage; + /// + /// DatasourceName to UserDelegatedAuthOptions for user-delegated authentication. + /// Only populated for data sources with user-delegated-auth enabled. + /// + private Dictionary _dataSourceUserDelegatedAuth; + + /// + /// DatasourceName to base Application Name for OBO per-user pooling. + /// Only populated for data sources with user-delegated-auth enabled. + /// Used as a prefix when constructing user-specific Application Names. + /// + private Dictionary _dataSourceBaseAppName; + + /// + /// Optional OBO token provider for user-delegated authentication. + /// + private readonly IOboTokenProvider? _oboTokenProvider; + private readonly RuntimeConfigProvider _runtimeConfigProvider; private const string QUERYIDHEADER = "QueryIdentifyingIds"; @@ -71,7 +91,8 @@ public MsSqlQueryExecutor( DbExceptionParser dbExceptionParser, ILogger logger, IHttpContextAccessor httpContextAccessor, - HotReloadEventHandler? 
handler = null) + HotReloadEventHandler? handler = null, + IOboTokenProvider? oboTokenProvider = null) : base(dbExceptionParser, logger, runtimeConfigProvider, @@ -80,9 +101,12 @@ public MsSqlQueryExecutor( { _dataSourceAccessTokenUsage = new Dictionary(); _dataSourceToSessionContextUsage = new Dictionary(); + _dataSourceUserDelegatedAuth = new Dictionary(); + _dataSourceBaseAppName = new Dictionary(); _accessTokensFromConfiguration = runtimeConfigProvider.ManagedIdentityAccessToken; _runtimeConfigProvider = runtimeConfigProvider; - ConfigureMsSqlQueryEecutor(); + _oboTokenProvider = oboTokenProvider; + ConfigureMsSqlQueryExecutor(); } /// @@ -99,9 +123,11 @@ public override SqlConnection CreateConnection(string dataSourceName) throw new DataApiBuilderException("Query execution failed. Could not find datasource to execute query against", HttpStatusCode.BadRequest, DataApiBuilderException.SubStatusCodes.DataSourceNotFound); } + string connectionString = GetConnectionStringForCurrentUser(dataSourceName); + SqlConnection conn = new() { - ConnectionString = ConnectionStringBuilders[dataSourceName].ConnectionString, + ConnectionString = connectionString, }; // Extract info message from SQLConnection @@ -135,10 +161,140 @@ public override SqlConnection CreateConnection(string dataSourceName) return conn; } + /// + /// Gets the connection string for the current user. For OBO-enabled data sources, + /// this returns a connection string with a user-specific Application Name to isolate + /// connection pools per user identity. + /// + /// The name of the data source. + /// The connection string to use for the current request. + private string GetConnectionStringForCurrentUser(string dataSourceName) + { + string baseConnectionString = ConnectionStringBuilders[dataSourceName].ConnectionString; + + // Per-user pooling is automatic when OBO is enabled. + // _dataSourceBaseAppName is only populated for data sources with user-delegated-auth enabled. 
+ if (!_dataSourceBaseAppName.TryGetValue(dataSourceName, out string? baseAppName)) + { + // OBO not enabled for this data source, use the standard connection string + return baseConnectionString; + } + + // Extract user pool key from current HTTP context (prefers oid, falls back to sub) + string? poolKeyHash = GetUserPoolKeyHash(dataSourceName); + if (string.IsNullOrEmpty(poolKeyHash)) + { + // For OBO-enabled data sources, we must have a user context for actual requests. + // Null poolKeyHash is only acceptable during startup/metadata phase when there's no HttpContext. + // If we have an HttpContext with a User but missing required claims, fail-safe to prevent + // potential cross-user connection pool contamination. + if (HttpContextAccessor?.HttpContext?.User?.Identity?.IsAuthenticated == true) + { + throw new DataApiBuilderException( + message: "User-delegated authentication requires 'iss' and user identifier (oid/sub) claims for connection pool isolation.", + statusCode: System.Net.HttpStatusCode.Unauthorized, + subStatusCode: DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure); + } + + // No user context (startup/metadata phase), use base connection string + return baseConnectionString; + } + + // Create a user-specific connection string with per-user pool isolation. + // Format: {hash}|{user-custom-appname} where hash is placed FIRST to ensure it's never truncated. + // SQL Server limits Application Name to 128 characters. By placing the hash first, we guarantee + // per-user pool isolation even if the user's custom app name gets truncated. + // The hash is a URL-safe Base64-encoded SHA256 hash (16 bytes = ~22 chars). + const int maxApplicationNameLength = 128; + string hashPrefix = $"{poolKeyHash}|"; + int allowedBaseAppNameLength = Math.Max(0, maxApplicationNameLength - hashPrefix.Length); + string effectiveBaseAppName = baseAppName.Length > allowedBaseAppNameLength + ? 
baseAppName[..allowedBaseAppNameLength] + : baseAppName; + + SqlConnectionStringBuilder userBuilder = new(baseConnectionString) + { + ApplicationName = $"{hashPrefix}{effectiveBaseAppName}" + }; + + return userBuilder.ConnectionString; + } + + /// + /// Generates a pool key hash from the current user's claims for OBO per-user pooling. + /// Uses iss|(oid||sub) to ensure each unique user identity gets its own connection pool. + /// Prefers 'oid' (stable GUID) but falls back to 'sub' for guest/B2B users. + /// + /// The data source name for logging purposes. + /// A URL-safe Base64-encoded hash, or null if no user context is available. + private string? GetUserPoolKeyHash(string dataSourceName) + { + if (HttpContextAccessor?.HttpContext?.User is null) + { + QueryExecutorLogger.LogDebug( + "Cannot create per-user pool key for data source {DataSourceName}: no HTTP context or user available.", + dataSourceName); + return null; + } + + ClaimsPrincipal user = HttpContextAccessor.HttpContext.User; + + // Extract issuer claim - required for tenant isolation and connection pool security. + // The "iss" claim must be present along with a user identifier (oid/sub) for per-user pooling. + // Callers are responsible for enforcing fail-safe behavior when claims are missing. + string? iss = user.FindFirst("iss")?.Value; + + // User identifier claim resolution (in priority order): + // 1. "oid" - Short claim name for object ID, used in Entra ID v2.0 tokens + // 2. Full URI form - "http://schemas.microsoft.com/identity/claims/objectidentifier" + // Used in Entra ID v1.0 tokens and some SAML-based flows + // 3. "sub" - Subject claim, unique per user per application. Used as fallback for + // guest/B2B users where oid may not be present or stable across tenants + // 4. ClaimTypes.NameIdentifier - .NET standard claim type (maps to various underlying claims) + // Acts as a last-resort fallback for non-Entra identity providers + string? userKey = user.FindFirst("oid")?.Value + ?? 
user.FindFirst("http://schemas.microsoft.com/identity/claims/objectidentifier")?.Value + ?? user.FindFirst("sub")?.Value + ?? user.FindFirst(ClaimTypes.NameIdentifier)?.Value; + + if (string.IsNullOrEmpty(iss) || string.IsNullOrEmpty(userKey)) + { + // Cannot create a pool key without both claims + QueryExecutorLogger.LogDebug( + "Cannot create per-user pool key for data source {DataSourceName}: missing {MissingClaim} claim.", + dataSourceName, + string.IsNullOrEmpty(iss) ? "iss" : "user identifier (oid/sub)"); + return null; + } + + // Create the pool key as iss|userKey and hash it to keep connection string small + string poolKey = $"{iss}|{userKey}"; + return HashPoolKey(poolKey); + } + + /// + /// Hashes the pool key using SHA256 truncated to 16 bytes for a compact, URL-safe identifier. + /// Uses SHA256 (SHA-2 family) with 128-bit truncation per Microsoft security requirements. + /// This produces a ~22 character hash (16 bytes Base64-encoded) that fits well within SQL Server's + /// 128-char Application Name limit while providing sufficient collision resistance. + /// + /// The pool key to hash (format: iss|oid or iss|sub). + /// A URL-safe Base64-encoded hash of the key (~22 characters). + private static string HashPoolKey(string key) + { + byte[] fullHash = System.Security.Cryptography.SHA256.HashData( + System.Text.Encoding.UTF8.GetBytes(key)); + // Truncate to 16 bytes (128 bits) per MS security requirements for SHA-2 family + return Convert.ToBase64String(fullHash, 0, 16) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + /// /// Configure during construction or a hot-reload scenario. 
/// - private void ConfigureMsSqlQueryEecutor() + private void ConfigureMsSqlQueryExecutor() { IEnumerable> mssqldbs = _runtimeConfigProvider.GetConfig().GetDataSourceNamesToDataSourcesIterator().Where(x => x.Value.DatabaseType is DatabaseType.MSSQL || x.Value.DatabaseType is DatabaseType.DWSQL); @@ -156,11 +312,25 @@ private void ConfigureMsSqlQueryEecutor() MsSqlOptions? msSqlOptions = dataSource.GetTypedOptions(); _dataSourceToSessionContextUsage[dataSourceName] = msSqlOptions is null ? false : msSqlOptions.SetSessionContext; _dataSourceAccessTokenUsage[dataSourceName] = ShouldManagedIdentityAccessBeAttempted(builder); + + // Track user-delegated authentication settings + if (dataSource.IsUserDelegatedAuthEnabled) + { + _dataSourceUserDelegatedAuth[dataSourceName] = dataSource.UserDelegatedAuth!; + + // Per-user pooling: Store the base Application Name for hash prefixing at connection time. + // We'll prepend the user's iss|oid (or iss|sub) hash to create isolated pools per user. + // Note: ApplicationName is typically already set by RuntimeConfigLoader (e.g., "CustomerApp,dab_oss_2.0.0") + // but we use GetDataApiBuilderUserAgent() as fallback for consistency. + // We respect the user's Pooling setting from the connection string. + _dataSourceBaseAppName[dataSourceName] = builder.ApplicationName ?? ProductInfo.GetDataApiBuilderUserAgent(); + } } } /// - /// Modifies the properties of the supplied connection to support managed identity access. + /// Modifies the properties of the supplied connection to support managed identity access + /// or user-delegated (OBO) authentication. /// In the case of MsSql, gets access token if deemed necessary and sets it on the connection. /// The supplied connection is assumed to already have the same connection string /// provided in the runtime configuration. 
@@ -175,13 +345,44 @@ public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection dataSourceName = ConfigProvider.GetConfig().DefaultDataSourceName; } + SqlConnection sqlConn = (SqlConnection)conn; + + // Check if user-delegated authentication is enabled for this data source + if (_dataSourceUserDelegatedAuth.TryGetValue(dataSourceName, out UserDelegatedAuthOptions? userDelegatedAuth)) + { + // Check if we're in an HTTP request context (not startup/metadata phase) + bool isInRequestContext = HttpContextAccessor?.HttpContext is not null; + + if (isInRequestContext) + { + // At runtime with an HTTP request - attempt OBO flow + // Note: DatabaseAudience is validated at startup by RuntimeConfigValidator + string? oboToken = await GetOboAccessTokenAsync(userDelegatedAuth.DatabaseAudience!); + if (oboToken is not null) + { + sqlConn.AccessToken = oboToken; + return; + } + + // OBO is enabled but we couldn't get a token (e.g., missing Bearer token in request) + // This is an error during request processing - we must not fall back to managed identity + throw new DataApiBuilderException( + message: DataApiBuilderException.OBO_MISSING_USER_CONTEXT, + statusCode: HttpStatusCode.Unauthorized, + subStatusCode: DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure); + } + + // At startup/metadata phase (no HTTP context) - fall through to use the configured + // connection string authentication (e.g., Managed Identity, SQL credentials, etc.) + // This allows DAB to read schema metadata at startup, while OBO is used for actual requests. 
+ QueryExecutorLogger.LogDebug("No HTTP context available - using configured connection string authentication for startup/metadata operations."); + } + _dataSourceAccessTokenUsage.TryGetValue(dataSourceName, out bool setAccessToken); // Only attempt to get the access token if the connection string is in the appropriate format if (setAccessToken) { - SqlConnection sqlConn = (SqlConnection)conn; - // If the configuration controller provided a managed identity access token use that, // else use the default saved access token if still valid. // Get a new token only if the saved token is null or expired. @@ -198,6 +399,37 @@ public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection } } + /// + /// Acquires an access token using On-Behalf-Of (OBO) flow for user-delegated authentication. + /// + /// The target database audience. + /// The OBO access token, or null if OBO cannot be performed. + private async Task GetOboAccessTokenAsync(string databaseAudience) + { + if (_oboTokenProvider is null || HttpContextAccessor?.HttpContext is null) + { + return null; + } + + HttpContext httpContext = HttpContextAccessor.HttpContext; + ClaimsPrincipal? principal = httpContext.User; + + // Extract the incoming JWT assertion from the Authorization header + string? authHeader = httpContext.Request.Headers["Authorization"].FirstOrDefault(); + if (string.IsNullOrWhiteSpace(authHeader) || !authHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + QueryExecutorLogger.LogWarning(DataApiBuilderException.OBO_MISSING_USER_CONTEXT); + return null; + } + + string incomingJwt = authHeader.Substring("Bearer ".Length).Trim(); + + return await _oboTokenProvider.GetAccessTokenOnBehalfOfAsync( + principal!, + incomingJwt, + databaseAudience); + } + /// /// Determines if managed identity access should be attempted or not. 
/// It should only be attempted, diff --git a/src/Core/Resolvers/MsalClientWrapper.cs b/src/Core/Resolvers/MsalClientWrapper.cs new file mode 100644 index 0000000000..b5bc15d5d1 --- /dev/null +++ b/src/Core/Resolvers/MsalClientWrapper.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Identity.Client; + +namespace Azure.DataApiBuilder.Core.Resolvers; + +/// +/// Implementation of that wraps +/// for OBO token acquisition. +/// +public sealed class MsalClientWrapper : IMsalClientWrapper +{ + private readonly IConfidentialClientApplication _msalClient; + + /// + /// Initializes a new instance of the class. + /// + /// The MSAL confidential client application. + public MsalClientWrapper(IConfidentialClientApplication msalClient) + { + _msalClient = msalClient ?? throw new ArgumentNullException(nameof(msalClient)); + } + + /// + public async Task AcquireTokenOnBehalfOfAsync( + string[] scopes, + string userAssertion, + CancellationToken cancellationToken = default) + { + UserAssertion assertion = new(userAssertion); + + return await _msalClient + .AcquireTokenOnBehalfOf(scopes, assertion) + .ExecuteAsync(cancellationToken); + } +} diff --git a/src/Core/Resolvers/OboSqlTokenProvider.cs b/src/Core/Resolvers/OboSqlTokenProvider.cs new file mode 100644 index 0000000000..a4bbf05030 --- /dev/null +++ b/src/Core/Resolvers/OboSqlTokenProvider.cs @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Net; +using System.Security.Claims; +using System.Security.Cryptography; +using System.Text; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Extensions.Logging; +using Microsoft.Identity.Client; +using ZiggyCreatures.Caching.Fusion; +using AuthenticationOptions = Azure.DataApiBuilder.Config.ObjectModel.AuthenticationOptions; + +namespace Azure.DataApiBuilder.Core.Resolvers; + +/// +/// Provides SQL access tokens acquired using On-Behalf-Of (OBO) flow +/// for user-delegated authentication against Microsoft Entra ID. +/// Uses FusionCache (L1 in-memory only) for token caching with automatic +/// expiration and eager refresh. +/// +public sealed class OboSqlTokenProvider : IOboTokenProvider +{ + private readonly IMsalClientWrapper _msalClient; + private readonly ILogger _logger; + private readonly IFusionCache _cache; + + /// + /// Cache key prefix for OBO tokens to isolate from other cached data. + /// + private const string CACHE_KEY_PREFIX = "obo:"; + + /// + /// Eager refresh threshold as a fraction of TTL. + /// At 0.85, a token cached for 60 minutes will be eagerly refreshed after 51 minutes. + /// + private const float EAGER_REFRESH_THRESHOLD = 0.85f; + + /// + /// Minimum buffer before token expiry to trigger a refresh (in minutes). + /// + private const int MIN_EARLY_REFRESH_MINUTES = 5; + + /// + /// Initializes a new instance of OboSqlTokenProvider. + /// + /// MSAL client wrapper for token acquisition. + /// Logger instance. + /// FusionCache instance for token caching (L1 in-memory only). + public OboSqlTokenProvider( + IMsalClientWrapper msalClient, + ILogger logger, + IFusionCache cache) + { + _msalClient = msalClient ?? throw new ArgumentNullException(nameof(msalClient)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _cache = cache ?? 
throw new ArgumentNullException(nameof(cache)); + } + + /// + public async Task GetAccessTokenOnBehalfOfAsync( + ClaimsPrincipal principal, + string incomingJwtAssertion, + string databaseAudience, + CancellationToken cancellationToken = default) + { + if (principal is null) + { + _logger.LogWarning("Cannot acquire OBO token: ClaimsPrincipal is null."); + return null; + } + + if (string.IsNullOrWhiteSpace(incomingJwtAssertion)) + { + _logger.LogWarning("Cannot acquire OBO token: Incoming JWT assertion is null or empty."); + return null; + } + + // Extract identity claims + string? subjectId = ExtractSubjectId(principal); + if (string.IsNullOrWhiteSpace(subjectId)) + { + _logger.LogWarning("Cannot acquire OBO token: Neither 'oid' nor 'sub' claim found in token."); + throw new DataApiBuilderException( + message: DataApiBuilderException.OBO_IDENTITY_CLAIMS_MISSING, + statusCode: HttpStatusCode.Unauthorized, + subStatusCode: DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure); + } + + string? tenantId = principal.FindFirst("tid")?.Value; + if (string.IsNullOrWhiteSpace(tenantId)) + { + _logger.LogWarning("Cannot acquire OBO token: 'tid' (tenant id) claim not found or empty in token."); + throw new DataApiBuilderException( + message: DataApiBuilderException.OBO_TENANT_CLAIM_MISSING, + statusCode: HttpStatusCode.Unauthorized, + subStatusCode: DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure); + } + + string authContextHash = ComputeAuthorizationContextHash(principal); + string cacheKey = BuildCacheKey(subjectId, tenantId, authContextHash); + + try + { + string[] scopes = [$"{databaseAudience.TrimEnd('/')}/.default"]; + + // Track whether we had a cache hit for logging + bool wasCacheMiss = false; + + // Use FusionCache GetOrSetAsync with factory pattern + // The factory is only called on cache miss + string? 
accessToken = await _cache.GetOrSetAsync( + key: cacheKey, + factory: async (ctx, ct) => + { + wasCacheMiss = true; + _logger.LogInformation( + "OBO token cache MISS for subject {SubjectId} (tenant: {TenantId}). Acquiring new token from Azure AD.", + subjectId, + tenantId); + + AuthenticationResult result = await _msalClient.AcquireTokenOnBehalfOfAsync( + scopes, + incomingJwtAssertion, + ct); + + // Calculate TTL based on token expiry with early refresh buffer + TimeSpan tokenLifetime = result.ExpiresOn - DateTimeOffset.UtcNow; + TimeSpan cacheDuration = tokenLifetime - TimeSpan.FromMinutes(MIN_EARLY_REFRESH_MINUTES); + + // Ensure minimum cache duration of 1 minute + if (cacheDuration < TimeSpan.FromMinutes(1)) + { + cacheDuration = TimeSpan.FromMinutes(1); + } + + // Set the cache duration based on actual token expiry + ctx.Options.SetDuration(cacheDuration); + + // Enable eager refresh - token will be refreshed in background at threshold + ctx.Options.SetEagerRefresh(EAGER_REFRESH_THRESHOLD); + + // Ensure tokens stay in L1 only (no distributed cache for security) + ctx.Options.SetSkipDistributedCache(true, true); + + _logger.LogInformation( + "OBO token ACQUIRED for subject {SubjectId}. Expires: {ExpiresOn}, Cache TTL: {CacheDuration}.", + subjectId, + result.ExpiresOn, + cacheDuration); + + return result.AccessToken; + }, + token: cancellationToken); + + if (!string.IsNullOrEmpty(accessToken) && !wasCacheMiss) + { + _logger.LogInformation("OBO token cache HIT for subject {SubjectId}.", subjectId); + } + + return accessToken; + } + catch (MsalException ex) + { + _logger.LogError( + ex, + "Failed to acquire OBO token for subject {SubjectId}. 
Error: {ErrorCode} - {Message}", + subjectId, + ex.ErrorCode, + ex.Message); + throw new DataApiBuilderException( + message: DataApiBuilderException.OBO_TOKEN_ACQUISITION_FAILED, + statusCode: HttpStatusCode.Unauthorized, + subStatusCode: DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure, + innerException: ex); + } + } + + /// + /// Extracts the subject identifier from the principal. + /// Prefers 'oid' claim (object ID) over 'sub' claim. + /// + private static string? ExtractSubjectId(ClaimsPrincipal principal) + { + string? oid = principal.FindFirst("oid")?.Value; + if (!string.IsNullOrWhiteSpace(oid)) + { + return oid; + } + + return principal.FindFirst("sub")?.Value; + } + + /// + /// Builds a canonical representation of permission-affecting claims (roles and scopes) + /// and computes a SHA-512 hash for use in the cache key. + /// + private static string ComputeAuthorizationContextHash(ClaimsPrincipal principal) + { + List values = []; + + foreach (Claim claim in principal.Claims) + { + if (claim.Type.Equals(AuthenticationOptions.ROLE_CLAIM_TYPE, StringComparison.OrdinalIgnoreCase) || + claim.Type.Equals("scp", StringComparison.OrdinalIgnoreCase)) + { + string[] parts = claim.Value.Split( + [' ', ','], + StringSplitOptions.RemoveEmptyEntries); + + foreach (string part in parts) + { + values.Add(part); + } + } + } + + if (values.Count == 0) + { + return ComputeSha512Hex(string.Empty); + } + + values.Sort(StringComparer.OrdinalIgnoreCase); + string canonical = string.Join("|", values); + return ComputeSha512Hex(canonical); + } + + /// + /// Computes SHA-512 hash and returns as hex string. + /// + private static string ComputeSha512Hex(string input) + { + byte[] data = Encoding.UTF8.GetBytes(input ?? string.Empty); + byte[] hash = SHA512.HashData(data); + return Convert.ToHexString(hash); + } + + /// + /// Builds the cache key from subject, tenant, and authorization context hash. 
+ /// Format: obo:subjectId+tenantId+authContextHash + /// + private static string BuildCacheKey(string subjectId, string tenantId, string authContextHash) + { + return $"{CACHE_KEY_PREFIX}{subjectId}{tenantId}{authContextHash}"; + } +} diff --git a/src/Core/Resolvers/SqlMutationEngine.cs b/src/Core/Resolvers/SqlMutationEngine.cs index dfc53449f8..69fefe4341 100644 --- a/src/Core/Resolvers/SqlMutationEngine.cs +++ b/src/Core/Resolvers/SqlMutationEngine.cs @@ -541,6 +541,84 @@ await queryExecutor.ExecuteQueryAsync( { if (context.OperationType is EntityActionOperation.Upsert || context.OperationType is EntityActionOperation.UpsertIncremental) { + // When no primary key values are provided (empty PrimaryKeyValuePairs), + // there is no row to look up for update. The upsert degenerates to a + // pure INSERT - execute it via the insert path so the mutation engine + // generates a correct INSERT statement instead of an UPDATE with an + // empty WHERE clause (WHERE 1 = 1) that would match every row. + if (context.PrimaryKeyValuePairs.Count == 0) + { + DbResultSetRow? 
insertResultRow = null; + + try + { + using (TransactionScope transactionScope = ConstructTransactionScopeBasedOnDbType(sqlMetadataProvider)) + { + insertResultRow = + await PerformMutationOperation( + entityName: context.EntityName, + operationType: EntityActionOperation.Insert, + parameters: parameters, + sqlMetadataProvider: sqlMetadataProvider); + + if (insertResultRow is null) + { + throw new DataApiBuilderException( + message: "An unexpected error occurred while trying to execute the query.", + statusCode: HttpStatusCode.InternalServerError, + subStatusCode: DataApiBuilderException.SubStatusCodes.UnexpectedError); + } + + if (insertResultRow.Columns.Count == 0) + { + throw new DataApiBuilderException( + message: "Could not insert row with given values.", + statusCode: HttpStatusCode.Forbidden, + subStatusCode: DataApiBuilderException.SubStatusCodes.DatabasePolicyFailure); + } + + if (isDatabasePolicyDefinedForReadAction) + { + FindRequestContext findRequestContext = ConstructFindRequestContext(context, insertResultRow, roleName, sqlMetadataProvider); + IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); + selectOperationResponse = await queryEngine.ExecuteAsync(findRequestContext); + } + + transactionScope.Complete(); + } + } + catch (TransactionException) + { + throw _dabExceptionWithTransactionErrorMessage; + } + + if (isReadPermissionConfiguredForRole && !isDatabasePolicyDefinedForReadAction) + { + IEnumerable allowedExposedColumns = _authorizationResolver.GetAllowedExposedColumns(context.EntityName, roleName, EntityActionOperation.Read); + foreach (string columnInResponse in insertResultRow.Columns.Keys) + { + if (!allowedExposedColumns.Contains(columnInResponse)) + { + insertResultRow.Columns.Remove(columnInResponse); + } + } + } + + string pkRouteForLocationHeader = isReadPermissionConfiguredForRole + ? 
SqlResponseHelpers.ConstructPrimaryKeyRoute(context, insertResultRow.Columns, sqlMetadataProvider) + : string.Empty; + + return SqlResponseHelpers.ConstructCreatedResultResponse( + insertResultRow.Columns, + selectOperationResponse, + pkRouteForLocationHeader, + isReadPermissionConfiguredForRole, + isDatabasePolicyDefinedForReadAction, + context.OperationType, + GetBaseRouteFromConfig(_runtimeConfigProvider.GetConfig()), + GetHttpContext()); + } + DbResultSet? upsertOperationResult; DbResultSetRow upsertOperationResultSetRow; diff --git a/src/Core/Services/MetadataProviders/CosmosSqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/CosmosSqlMetadataProvider.cs index 61ffeeab09..5b9b2f935a 100644 --- a/src/Core/Services/MetadataProviders/CosmosSqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/CosmosSqlMetadataProvider.cs @@ -55,7 +55,7 @@ public class CosmosSqlMetadataProvider : ISqlMetadataProvider public List SqlMetadataExceptions { get; private set; } = new(); - public CosmosSqlMetadataProvider(RuntimeConfigProvider runtimeConfigProvider, IFileSystem fileSystem) + public CosmosSqlMetadataProvider(RuntimeConfigProvider runtimeConfigProvider, RuntimeConfigValidator runtimeConfigValidator, IFileSystem fileSystem) { RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); _fileSystem = fileSystem; @@ -76,6 +76,7 @@ public CosmosSqlMetadataProvider(RuntimeConfigProvider runtimeConfigProvider, IF subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); } + runtimeConfigValidator.ValidateEntityAndAutoentityConfigurations(runtimeConfig); _cosmosDb = cosmosDb; ParseSchemaGraphQLDocument(); diff --git a/src/Core/Services/MetadataProviders/MetadataProviderFactory.cs b/src/Core/Services/MetadataProviders/MetadataProviderFactory.cs index 66112fce21..6fe20969ed 100644 --- a/src/Core/Services/MetadataProviders/MetadataProviderFactory.cs +++ b/src/Core/Services/MetadataProviders/MetadataProviderFactory.cs @@ -19,6 +19,7 @@ 
public class MetadataProviderFactory : IMetadataProviderFactory { private readonly IDictionary _metadataProviders; private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly RuntimeConfigValidator _runtimeConfigValidator; private readonly IAbstractQueryManagerFactory _queryManagerFactory; private readonly ILogger _logger; private readonly IFileSystem _fileSystem; @@ -26,6 +27,7 @@ public class MetadataProviderFactory : IMetadataProviderFactory public MetadataProviderFactory( RuntimeConfigProvider runtimeConfigProvider, + RuntimeConfigValidator runtimeConfigValidator, IAbstractQueryManagerFactory queryManagerFactory, ILogger logger, IFileSystem fileSystem, @@ -34,6 +36,7 @@ public MetadataProviderFactory( { handler?.Subscribe(METADATA_PROVIDER_FACTORY_ON_CONFIG_CHANGED, OnConfigChanged); _runtimeConfigProvider = runtimeConfigProvider; + _runtimeConfigValidator = runtimeConfigValidator; _queryManagerFactory = queryManagerFactory; _logger = logger; _fileSystem = fileSystem; @@ -48,11 +51,11 @@ private void ConfigureMetadataProviders() { ISqlMetadataProvider metadataProvider = dataSource.DatabaseType switch { - DatabaseType.CosmosDB_NoSQL => new CosmosSqlMetadataProvider(_runtimeConfigProvider, _fileSystem), - DatabaseType.MSSQL => new MsSqlMetadataProvider(_runtimeConfigProvider, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), - DatabaseType.DWSQL => new MsSqlMetadataProvider(_runtimeConfigProvider, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), - DatabaseType.PostgreSQL => new PostgreSqlMetadataProvider(_runtimeConfigProvider, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), - DatabaseType.MySQL => new MySqlMetadataProvider(_runtimeConfigProvider, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), + DatabaseType.CosmosDB_NoSQL => new CosmosSqlMetadataProvider(_runtimeConfigProvider, _runtimeConfigValidator, _fileSystem), + DatabaseType.MSSQL => new 
MsSqlMetadataProvider(_runtimeConfigProvider, _runtimeConfigValidator, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), + DatabaseType.DWSQL => new MsSqlMetadataProvider(_runtimeConfigProvider, _runtimeConfigValidator, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), + DatabaseType.PostgreSQL => new PostgreSqlMetadataProvider(_runtimeConfigProvider, _runtimeConfigValidator, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), + DatabaseType.MySQL => new MySqlMetadataProvider(_runtimeConfigProvider, _runtimeConfigValidator, _queryManagerFactory, _logger, dataSourceName, _isValidateOnly), _ => throw new NotSupportedException(dataSource.DatabaseTypeNotSupportedMessage), }; diff --git a/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs index 7d02798427..96fa47dcfd 100644 --- a/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs @@ -33,11 +33,12 @@ public class MsSqlMetadataProvider : public MsSqlMetadataProvider( RuntimeConfigProvider runtimeConfigProvider, + RuntimeConfigValidator runtimeConfigValidator, IAbstractQueryManagerFactory queryManagerFactory, ILogger logger, string dataSourceName, bool isValidateOnly = false) - : base(runtimeConfigProvider, queryManagerFactory, logger, dataSourceName, isValidateOnly) + : base(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory, logger, dataSourceName, isValidateOnly) { _runtimeConfigProvider = runtimeConfigProvider; } @@ -290,5 +291,122 @@ private bool TryResolveDbType(string sqlDbTypeName, out DbType dbType) return false; } } + + /// + protected override async Task GenerateAutoentitiesIntoEntities(IReadOnlyDictionary? 
autoentities) + { + if (autoentities is null) + { + return; + } + + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + Dictionary entities = new(); + foreach ((string autoentityName, Autoentity autoentity) in autoentities) + { + int addedEntities = 0; + JsonArray? resultArray = await QueryAutoentitiesAsync(autoentity); + if (resultArray is null) + { + continue; + } + + foreach (JsonObject? resultObject in resultArray) + { + if (resultObject is null) + { + throw new DataApiBuilderException( + message: $"Cannot create new entity from autoentity pattern due to an internal error.", + statusCode: HttpStatusCode.InternalServerError, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + // Extract the entity name, schema, and database object name from the query result. + // The SQL query returns these values with placeholders already replaced. + string? entityName = resultObject["entity_name"]?.ToString(); + string? objectName = resultObject["object"]?.ToString(); + string? schemaName = resultObject["schema"]?.ToString(); + + if (string.IsNullOrWhiteSpace(entityName) || string.IsNullOrWhiteSpace(objectName) || string.IsNullOrWhiteSpace(schemaName)) + { + _logger.LogError("Skipping autoentity generation: entity_name or object is null or empty for autoentity pattern '{AutoentityName}'.", autoentityName); + continue; + } + + // Create the entity using the template settings and permissions from the autoentity configuration. + // Currently the source type is always Table for auto-generated entities from database objects. 
+ Entity generatedEntity = new( + Source: new EntitySource( + Object: objectName, + Type: EntitySourceType.Table, + Parameters: null, + KeyFields: null), + GraphQL: autoentity.Template.GraphQL, + Rest: autoentity.Template.Rest, + Mcp: autoentity.Template.Mcp, + Permissions: autoentity.Permissions, + Cache: autoentity.Template.Cache, + Health: autoentity.Template.Health, + Fields: null, + Relationships: null, + Mappings: new()); + + // Add the generated entity to the linking entities dictionary. + // This allows the entity to be processed later during metadata population. + if (!entities.TryAdd(entityName, generatedEntity) || !runtimeConfig.TryAddGeneratedAutoentityNameToDataSourceName(entityName, autoentityName)) + { + throw new DataApiBuilderException( + message: $"Entity with name '{entityName}' already exists. Cannot create new entity from autoentity pattern with definition-name '{autoentityName}'.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + if (runtimeConfig.IsRestEnabled) + { + _logger.LogInformation("[{entity}] REST path: {globalRestPath}/{entityRestPath}", entityName, runtimeConfig.RestPath, entityName); + } + else + { + _logger.LogInformation(message: "REST calls are disabled for the entity: {entity}", entityName); + } + + addedEntities++; + } + + if (addedEntities == 0) + { + _logger.LogWarning("No new entities were generated from the autoentity {autoentityName} defined in the configuration.", autoentityName); + } + } + + _runtimeConfigProvider.AddMergedEntitiesToConfig(entities); + } + + public async Task QueryAutoentitiesAsync(Autoentity autoentity) + { + string include = string.Join(",", autoentity.Patterns.Include); + string exclude = string.Join(",", autoentity.Patterns.Exclude); + string namePattern = autoentity.Patterns.Name; + string getAutoentitiesQuery = SqlQueryBuilder.BuildGetAutoentitiesQuery(); + Dictionary parameters = new() + { + { 
$"{BaseQueryStructure.PARAM_NAME_PREFIX}include_pattern", new(include, null, SqlDbType.NVarChar) }, + { $"{BaseQueryStructure.PARAM_NAME_PREFIX}exclude_pattern", new(exclude, null, SqlDbType.NVarChar) }, + { $"{BaseQueryStructure.PARAM_NAME_PREFIX}name_pattern", new(namePattern, null, SqlDbType.NVarChar) } + }; + + _logger.LogInformation("Query for Autoentities is being executed with the following parameters."); + _logger.LogInformation($"Autoentities include pattern: {include}"); + _logger.LogInformation($"Autoentities exclude pattern: {exclude}"); + _logger.LogInformation($"Autoentities name pattern: {namePattern}"); + + JsonArray? resultArray = await QueryExecutor.ExecuteQueryAsync( + sqltext: getAutoentitiesQuery, + parameters: parameters, + dataReaderHandler: QueryExecutor.GetJsonArrayAsync, + dataSourceName: _dataSourceName); + + return resultArray; + } } } diff --git a/src/Core/Services/MetadataProviders/MySqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/MySqlMetadataProvider.cs index 99336180d5..26098c2d15 100644 --- a/src/Core/Services/MetadataProviders/MySqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/MySqlMetadataProvider.cs @@ -23,11 +23,12 @@ public class MySqlMetadataProvider : SqlMetadataProvider logger, string dataSourceName, bool isValidateOnly = false) - : base(runtimeConfigProvider, queryManagerFactory, logger, dataSourceName, isValidateOnly) + : base(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory, logger, dataSourceName, isValidateOnly) { try { diff --git a/src/Core/Services/MetadataProviders/PostgreSqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/PostgreSqlMetadataProvider.cs index ecd65b3d95..0d43d0efbc 100644 --- a/src/Core/Services/MetadataProviders/PostgreSqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/PostgreSqlMetadataProvider.cs @@ -22,11 +22,12 @@ public class PostgreSqlMetadataProvider : public PostgreSqlMetadataProvider( RuntimeConfigProvider 
runtimeConfigProvider, + RuntimeConfigValidator runtimeConfigValidator, IAbstractQueryManagerFactory queryManagerFactory, ILogger logger, string dataSourceName, bool isValidateOnly = false) - : base(runtimeConfigProvider, queryManagerFactory, logger, dataSourceName, isValidateOnly) + : base(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory, logger, dataSourceName, isValidateOnly) { } diff --git a/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs index 8553e08136..6aa2712468 100644 --- a/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Collections.ObjectModel; using System.Data; using System.Data.Common; using System.Diagnostics.CodeAnalysis; @@ -38,14 +39,17 @@ public abstract class SqlMetadataProvider : private readonly DatabaseType _databaseType; - // Represents the entities exposed in the runtime config. - private IReadOnlyDictionary _entities; - // Represents the linking entities created by DAB to support multiple mutations for entities having an M:N relationship between them. protected Dictionary _linkingEntities = new(); protected readonly string _dataSourceName; + // Represents the entities exposed in the runtime config. + private IReadOnlyDictionary Entities => new ReadOnlyDictionary(_runtimeConfigProvider.GetConfig().Entities.Where(x => string.Equals(_runtimeConfigProvider.GetConfig().GetDataSourceNameFromEntityName(x.Key), _dataSourceName, StringComparison.OrdinalIgnoreCase)).ToDictionary(x => x.Key, x => x.Value)); + + // Represents the autoentities exposed in the runtime config. 
+ private IReadOnlyDictionary Autoentities => new ReadOnlyDictionary(_runtimeConfigProvider.GetConfig().Autoentities.Where(x => string.Equals(_runtimeConfigProvider.GetConfig().GetDataSourceNameFromAutoentityName(x.Key), _dataSourceName, StringComparison.OrdinalIgnoreCase)).ToDictionary(x => x.Key, x => x.Value)); + // Dictionary containing mapping of graphQL stored procedure exposed query/mutation name // to their corresponding entity names defined in the config. public Dictionary GraphQLStoredProcedureExposedNameToEntityNameMap { get; set; } = new(); @@ -73,6 +77,8 @@ public abstract class SqlMetadataProvider : private RuntimeConfigProvider _runtimeConfigProvider; + private RuntimeConfigValidator _runtimeConfigValidator; + private Dictionary> EntityBackingColumnsToExposedNames { get; } = new(); private Dictionary> EntityExposedNamesToBackingColumnNames { get; } = new(); @@ -104,6 +110,7 @@ private void HandleOrRecordException(Exception e) public SqlMetadataProvider( RuntimeConfigProvider runtimeConfigProvider, + RuntimeConfigValidator runtimeConfigValidator, IAbstractQueryManagerFactory engineFactory, ILogger logger, string dataSourceName, @@ -111,12 +118,12 @@ public SqlMetadataProvider( { RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); _runtimeConfigProvider = runtimeConfigProvider; + _runtimeConfigValidator = runtimeConfigValidator; _dataSourceName = dataSourceName; _databaseType = runtimeConfig.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; - _entities = runtimeConfig.Entities.Where(x => string.Equals(runtimeConfig.GetDataSourceNameFromEntityName(x.Key), _dataSourceName, StringComparison.OrdinalIgnoreCase)).ToDictionary(x => x.Key, x => x.Value); _logger = logger; _isValidateOnly = isValidateOnly; - foreach ((string entityName, Entity entityMetatdata) in _entities) + foreach ((string entityName, Entity entityMetatdata) in Entities) { if (runtimeConfig.IsRestEnabled) { @@ -227,7 +234,7 @@ public bool 
TryGetExposedColumnName(string entityName, string backingFieldName, return true; } - if (_entities.TryGetValue(entityName, out Entity? entityDefinition) && entityDefinition.Fields is not null) + if (Entities.TryGetValue(entityName, out Entity? entityDefinition) && entityDefinition.Fields is not null) { // Find the field by backing name and use its Alias if present. FieldMetadata? matched = entityDefinition @@ -260,7 +267,7 @@ public bool TryGetBackingColumn(string entityName, string field, [NotNullWhen(tr return true; } - if (_entities.TryGetValue(entityName, out Entity? entityDefinition) && entityDefinition.Fields is not null) + if (Entities.TryGetValue(entityName, out Entity? entityDefinition) && entityDefinition.Fields is not null) { FieldMetadata? matchedField = entityDefinition.Fields.FirstOrDefault(f => f.Alias != null && f.Alias.Equals(field, StringComparison.OrdinalIgnoreCase)); @@ -284,12 +291,12 @@ public IReadOnlyDictionary GetEntityNamesAndDbObjects() /// public string GetEntityName(string graphQLType) { - if (_entities.ContainsKey(graphQLType)) + if (Entities.ContainsKey(graphQLType)) { return graphQLType; } - foreach ((string entityName, Entity entity) in _entities) + foreach ((string entityName, Entity entity) in Entities) { if (entity.GraphQL.Singular == graphQLType) { @@ -307,7 +314,7 @@ public string GetEntityName(string graphQLType) public async Task InitializeAsync() { System.Diagnostics.Stopwatch timer = System.Diagnostics.Stopwatch.StartNew(); - GenerateDatabaseObjectForEntities(); + if (_isValidateOnly) { // Currently Validate mode only support single datasource, @@ -324,8 +331,20 @@ public async Task InitializeAsync() } } + if (GetDatabaseType() == DatabaseType.MSSQL) + { + await GenerateAutoentitiesIntoEntities(Autoentities); + } + + // Running these entity validations only in development mode to ensure + // fast startup of engine in production mode. 
+ RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + _runtimeConfigValidator.ValidateEntityAndAutoentityConfigurations(runtimeConfig); + + GenerateDatabaseObjectForEntities(); await PopulateObjectDefinitionForEntities(); GenerateExposedToBackingColumnMapsForEntities(); + // When IsLateConfigured is true we are in a hosted scenario and do not reveal primary key information. if (!_runtimeConfigProvider.IsLateConfigured) { @@ -384,7 +403,7 @@ public bool TryGetBackingFieldToExposedFieldMap(string entityName, [NotNullWhen( private void LogPrimaryKeys() { ColumnDefinition column; - foreach ((string entityName, Entity _) in _entities) + foreach ((string entityName, Entity _) in Entities) { try { @@ -548,7 +567,7 @@ private void GenerateRestPathToEntityMap() RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); string graphQLGlobalPath = runtimeConfig.GraphQLPath; - foreach ((string entityName, Entity entity) in _entities) + foreach ((string entityName, Entity entity) in Entities) { try { @@ -680,12 +699,21 @@ protected virtual Dictionary private void GenerateDatabaseObjectForEntities() { Dictionary sourceObjects = new(); - foreach ((string entityName, Entity entity) in _entities) + foreach ((string entityName, Entity entity) in Entities) { PopulateDatabaseObjectForEntity(entity, entityName, sourceObjects); } } + /// + /// Creates entities for each table that is found, based on the autoentity configuration. + /// This method is only called for tables in MsSql. + /// + protected virtual Task GenerateAutoentitiesIntoEntities(IReadOnlyDictionary? autoentities) + { + throw new NotSupportedException($"{GetType().Name} does not support Autoentities yet."); + } + protected void PopulateDatabaseObjectForEntity( Entity entity, string entityName, @@ -810,7 +838,7 @@ private void ProcessRelationships( foreach ((string relationshipName, EntityRelationship relationship) in entity.Relationships!) 
{ string targetEntityName = relationship.TargetEntity; - if (!_entities.TryGetValue(targetEntityName, out Entity? targetEntity)) + if (!Entities.TryGetValue(targetEntityName, out Entity? targetEntity)) { throw new InvalidOperationException($"Target Entity {targetEntityName} should be one of the exposed entities."); } @@ -1092,7 +1120,7 @@ public IReadOnlyDictionary GetLinkingEntities() /// private async Task PopulateObjectDefinitionForEntities() { - foreach ((string entityName, Entity entity) in _entities) + foreach ((string entityName, Entity entity) in Entities) { await PopulateObjectDefinitionForEntity(entityName, entity); } @@ -1291,7 +1319,7 @@ private async Task PopulateResultSetDefinitionsForStoredProcedureAsync( /// private void GenerateExposedToBackingColumnMapsForEntities() { - foreach ((string entityName, Entity _) in _entities) + foreach ((string entityName, Entity _) in Entities) { GenerateExposedToBackingColumnMapUtil(entityName); } @@ -1316,7 +1344,7 @@ private void GenerateExposedToBackingColumnMapUtil(string entityName) Dictionary exposedToBack = new(StringComparer.OrdinalIgnoreCase); // Pull definitions. - _entities.TryGetValue(entityName, out Entity? entity); + Entities.TryGetValue(entityName, out Entity? entity); SourceDefinition sourceDefinition = GetSourceDefinition(entityName); // 1) Prefer new-style fields (backing = f.Name, exposed = f.Alias ?? f.Name) @@ -1412,7 +1440,7 @@ private async Task PopulateSourceDefinitionAsync( subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); } - _entities.TryGetValue(entityName, out Entity? entity); + Entities.TryGetValue(entityName, out Entity? 
entity); if (GetDatabaseType() is DatabaseType.MSSQL && entity is not null && entity.Source.Type is EntitySourceType.Table) { await PopulateTriggerMetadataForTable(entityName, schemaName, tableName, sourceDefinition); diff --git a/src/Core/Services/OpenAPI/IOpenApiDocumentor.cs b/src/Core/Services/OpenAPI/IOpenApiDocumentor.cs index d24fd52c4a..4158b8e654 100644 --- a/src/Core/Services/OpenAPI/IOpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/IOpenApiDocumentor.cs @@ -13,11 +13,20 @@ public interface IOpenApiDocumentor { /// /// Attempts to return the OpenAPI description document, if generated. + /// Returns the superset of all roles' permissions. /// /// String representation of JSON OpenAPI description document. /// True (plus string representation of document), when document exists. False, otherwise. public bool TryGetDocument([NotNullWhen(true)] out string? document); + /// + /// Attempts to return a role-specific OpenAPI description document. + /// + /// The role name to filter permissions. + /// String representation of JSON OpenAPI description document. + /// True if role exists and document generated. False if role not found. + public bool TryGetDocumentForRole(string role, [NotNullWhen(true)] out string? document); + /// /// Creates an OpenAPI description document using OpenAPI.NET. /// Document compliant with patches of OpenAPI V3.0 spec 3.0.0 and 3.0.1, diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 69e009b55f..2eccf279bb 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Globalization; -using System.Linq; using System.Net; using System.Net.Mime; using System.Text; @@ -17,6 +17,7 @@ using Azure.DataApiBuilder.Core.Services.OpenAPI; using Azure.DataApiBuilder.Product; using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Extensions.Logging; using Microsoft.OpenApi.Any; using Microsoft.OpenApi.Models; using Microsoft.OpenApi.Writers; @@ -32,14 +33,17 @@ public class OpenApiDocumentor : IOpenApiDocumentor { private readonly IMetadataProviderFactory _metadataProviderFactory; private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly ILogger _logger; private OpenApiResponses _defaultOpenApiResponses; private OpenApiDocument? _openApiDocument; + private readonly ConcurrentDictionary _roleSpecificDocuments = new(StringComparer.OrdinalIgnoreCase); private const string DOCUMENTOR_UI_TITLE = "Data API builder - REST Endpoint"; private const string GETALL_DESCRIPTION = "Returns entities."; private const string GETONE_DESCRIPTION = "Returns an entity."; private const string POST_DESCRIPTION = "Create entity."; private const string PUT_DESCRIPTION = "Replace or create entity."; + private const string PUT_PATCH_KEYLESS_DESCRIPTION = "Create entity (keyless). For entities with auto-generated primary keys, creates a new record without requiring the key in the URL."; private const string PATCH_DESCRIPTION = "Update or create entity."; private const string DELETE_DESCRIPTION = "Delete entity."; private const string SP_EXECUTE_DESCRIPTION = "Executes a stored procedure."; @@ -63,19 +67,27 @@ public class OpenApiDocumentor : IOpenApiDocumentor /// /// Constructor denotes required services whose metadata is used to generate the OpenAPI description document. /// - /// Provides database object metadata. + /// Provides database object metadata. /// Provides entity/REST path metadata. 
- public OpenApiDocumentor(IMetadataProviderFactory metadataProviderFactory, RuntimeConfigProvider runtimeConfigProvider, HotReloadEventHandler? handler) + /// Hot reload event handler. + /// Logger for diagnostic information. + public OpenApiDocumentor( + IMetadataProviderFactory metadataProviderFactory, + RuntimeConfigProvider runtimeConfigProvider, + HotReloadEventHandler? handler, + ILogger logger) { handler?.Subscribe(DOCUMENTOR_ON_CONFIG_CHANGED, OnConfigChanged); _metadataProviderFactory = metadataProviderFactory; _runtimeConfigProvider = runtimeConfigProvider; + _logger = logger; _defaultOpenApiResponses = CreateDefaultOpenApiResponses(); } public void OnConfigChanged(object? sender, HotReloadEventArgs args) { CreateDocument(doOverrideExistingDocument: true); + _roleSpecificDocuments.Clear(); // Clear role-specific document cache on config change } /// @@ -102,6 +114,133 @@ public bool TryGetDocument([NotNullWhen(true)] out string? document) } } + /// + /// Attempts to return a role-specific OpenAPI description document. + /// + /// The role name to filter permissions (case-insensitive). + /// String representation of JSON OpenAPI description document. + /// True if role exists and document generated. False if role not found or empty/whitespace. + public bool TryGetDocumentForRole(string role, [NotNullWhen(true)] out string? document) + { + document = null; + + // Validate role is not null, empty, or whitespace + if (string.IsNullOrWhiteSpace(role)) + { + return false; + } + + // Check cache first + if (_roleSpecificDocuments.TryGetValue(role, out document)) + { + return true; + } + + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + + // Check if the role exists in any entity's permissions using LINQ + bool roleExists = runtimeConfig.Entities + .Any(kvp => kvp.Value.Permissions?.Any(p => string.Equals(p.Role, role, StringComparison.OrdinalIgnoreCase)) == true); + + if (!roleExists) + { + return false; + } + + try + { + OpenApiDocument? 
roleDoc = GenerateDocumentForRole(runtimeConfig, role); + if (roleDoc is null) + { + return false; + } + + using StringWriter textWriter = new(CultureInfo.InvariantCulture); + OpenApiJsonWriter jsonWriter = new(textWriter); + roleDoc.SerializeAsV3(jsonWriter); + document = textWriter.ToString(); + + // Cache the role-specific document + _roleSpecificDocuments.TryAdd(role, document); + + return true; + } + catch (Exception ex) + { + // Log exception details for debugging document generation failures + _logger.LogError(ex, "Failed to generate OpenAPI document for role '{Role}'", role); + return false; + } + } + + /// + /// Generates an OpenAPI document filtered for a specific role. + /// + private OpenApiDocument? GenerateDocumentForRole(RuntimeConfig runtimeConfig, string role) + { + string title = $"{DOCUMENTOR_UI_TITLE} - {role}"; + return BuildOpenApiDocument(runtimeConfig, role, title); + } + + /// + /// Builds an OpenAPI document with optional role-based filtering. + /// Shared logic for both superset and role-specific document generation. + /// + /// Runtime configuration. + /// Optional role to filter permissions. If null, returns superset of all roles. + /// Document title. + /// OpenAPI document. + private OpenApiDocument BuildOpenApiDocument(RuntimeConfig runtimeConfig, string? role, string title) + { + string restEndpointPath = runtimeConfig.RestPath; + string? runtimeBaseRoute = runtimeConfig.Runtime?.BaseRoute; + string url = string.IsNullOrEmpty(runtimeBaseRoute) ? restEndpointPath : runtimeBaseRoute + "/" + restEndpointPath; + + OpenApiComponents components = new() + { + Schemas = CreateComponentSchemas(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName, role, isRequestBodyStrict: runtimeConfig.IsRequestBodyStrict) + }; + + // Store tags in a dictionary keyed by normalized REST path to ensure we can + // reuse the same tag instances in BuildPaths, preventing duplicate groups in Swagger UI. 
+ Dictionary globalTagsDict = new(); + foreach (KeyValuePair kvp in runtimeConfig.Entities) + { + Entity entity = kvp.Value; + if (!entity.Rest.Enabled || !HasAnyAvailableOperations(entity, role)) + { + continue; + } + + // Use GetEntityRestPath to ensure consistent path normalization (with leading slash trimmed) + // matching the same computation used in BuildPaths. + string restPath = GetEntityRestPath(entity.Rest, kvp.Key); + + // First entity's description wins when multiple entities share the same REST path. + globalTagsDict.TryAdd(restPath, new OpenApiTag + { + Name = restPath, + Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description + }); + } + + return new OpenApiDocument() + { + Info = new OpenApiInfo + { + Version = ProductInfo.GetProductVersion(), + Title = title + }, + Servers = new List + { + new() { Url = url } + }, + Paths = BuildPaths(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName, globalTagsDict, role), + Components = components, + Tags = globalTagsDict.Values.ToList() + }; + } + /// /// Creates an OpenAPI description document using OpenAPI.NET. /// Document compliant with patches of OpenAPI V3.0 spec 3.0.0 and 3.0.1, @@ -131,47 +270,7 @@ public void CreateDocument(bool doOverrideExistingDocument = false) try { - string restEndpointPath = runtimeConfig.RestPath; - string? runtimeBaseRoute = runtimeConfig.Runtime?.BaseRoute; - string url = string.IsNullOrEmpty(runtimeBaseRoute) ? 
restEndpointPath : runtimeBaseRoute + "/" + restEndpointPath; - OpenApiComponents components = new() - { - Schemas = CreateComponentSchemas(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName) - }; - - // Collect all entity tags and their descriptions for the top-level tags array - // Store tags in a dictionary to ensure we can reuse the same tag instances in BuildPaths - Dictionary globalTagsDict = new(); - foreach (KeyValuePair kvp in runtimeConfig.Entities) - { - Entity entity = kvp.Value; - // Use GetEntityRestPath to ensure consistent path computation (with leading slash trimmed) - string restPath = GetEntityRestPath(entity.Rest, kvp.Key); - - // Only add the tag if it hasn't been added yet (handles entities with the same REST path) - // First entity's description wins when multiple entities share the same REST path. - globalTagsDict.TryAdd(restPath, new OpenApiTag - { - Name = restPath, - Description = string.IsNullOrWhiteSpace(entity.Description) ? null : entity.Description - }); - } - - OpenApiDocument doc = new() - { - Info = new OpenApiInfo - { - Version = ProductInfo.GetProductVersion(), - Title = DOCUMENTOR_UI_TITLE - }, - Servers = new List - { - new() { Url = url } - }, - Paths = BuildPaths(runtimeConfig.Entities, runtimeConfig.DefaultDataSourceName, globalTagsDict), - Components = components, - Tags = globalTagsDict.Values.ToList() - }; + OpenApiDocument doc = BuildOpenApiDocument(runtimeConfig, role: null, title: DOCUMENTOR_UI_TITLE); _openApiDocument = doc; } catch (Exception ex) @@ -198,8 +297,10 @@ public void CreateDocument(bool doOverrideExistingDocument = false) /// A path with no primary key nor parameter representing the primary key value: /// "/EntityName" /// + /// Dictionary of global tags keyed by normalized REST path for reuse. + /// Optional role to filter permissions. If null, returns superset of all roles. /// All possible paths in the DAB engine's REST API endpoint. 
- private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSourceName, Dictionary globalTags) + private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSourceName, Dictionary globalTags, string? role = null) { OpenApiPaths pathsCollection = new(); @@ -233,26 +334,26 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour continue; } - // Reuse the existing tag from the global tags dictionary instead of creating a new one - // This ensures Swagger UI displays only one group per entity - List tags = new(); - if (globalTags.TryGetValue(entityRestPath, out OpenApiTag? existingTag)) + // Reuse the existing tag from the global tags dictionary instead of creating a new instance. + // This ensures Swagger UI displays only one group per entity by using the same object reference. + if (!globalTags.TryGetValue(entityRestPath, out OpenApiTag? existingTag)) { - tags.Add(existingTag); + _logger.LogWarning("Tag for REST path '{EntityRestPath}' not found in global tags dictionary. This indicates a key mismatch between BuildOpenApiDocument and BuildPaths.", entityRestPath); + continue; } - else + + List tags = new() { - // Fallback: create a new tag if not found in global tags. - // This should not happen in normal flow if GetEntityRestPath is used consistently. - // If this path is reached, it indicates a key mismatch between CreateDocument and BuildPaths. - tags.Add(new OpenApiTag - { - Name = entityRestPath, - Description = string.IsNullOrWhiteSpace(entity.Description) ? 
null : entity.Description - }); - } + existingTag + }; + + Dictionary configuredRestOperations = GetConfiguredRestOperations(entity, dbObject, role); - Dictionary configuredRestOperations = GetConfiguredRestOperations(entity, dbObject); + // Skip entities with no available operations + if (!configuredRestOperations.ContainsValue(true)) + { + continue; + } if (dbObject.SourceType is EntitySourceType.StoredProcedure) { @@ -262,12 +363,15 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour configuredRestOperations: configuredRestOperations, tags: tags); - OpenApiPathItem openApiPathItem = new() + if (operations.Count > 0) { - Operations = operations - }; + OpenApiPathItem openApiPathItem = new() + { + Operations = operations + }; - pathsCollection.TryAdd(entityBasePathComponent, openApiPathItem); + pathsCollection.TryAdd(entityBasePathComponent, openApiPathItem); + } } else { @@ -277,33 +381,41 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour entityName: entityName, sourceDefinition: sourceDefinition, includePrimaryKeyPathComponent: true, + configuredRestOperations: configuredRestOperations, tags: tags); - Tuple> pkComponents = CreatePrimaryKeyPathComponentAndParameters(entityName, metadataProvider); - string pkPathComponents = pkComponents.Item1; - string fullPathComponent = entityBasePathComponent + pkPathComponents; - - OpenApiPathItem openApiPkPathItem = new() + if (pkOperations.Count > 0) { - Operations = pkOperations, - Parameters = pkComponents.Item2 - }; + Tuple> pkComponents = CreatePrimaryKeyPathComponentAndParameters(entityName, metadataProvider); + string pkPathComponents = pkComponents.Item1; + string fullPathComponent = entityBasePathComponent + pkPathComponents; - pathsCollection.TryAdd(fullPathComponent, openApiPkPathItem); + OpenApiPathItem openApiPkPathItem = new() + { + Operations = pkOperations, + Parameters = pkComponents.Item2 + }; + + pathsCollection.TryAdd(fullPathComponent, 
openApiPkPathItem); + } // Operations excluding primary key Dictionary operations = CreateOperations( entityName: entityName, sourceDefinition: sourceDefinition, includePrimaryKeyPathComponent: false, + configuredRestOperations: configuredRestOperations, tags: tags); - OpenApiPathItem openApiPathItem = new() + if (operations.Count > 0) { - Operations = operations - }; + OpenApiPathItem openApiPathItem = new() + { + Operations = operations + }; - pathsCollection.TryAdd(entityBasePathComponent, openApiPathItem); + pathsCollection.TryAdd(entityBasePathComponent, openApiPathItem); + } } } @@ -319,6 +431,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour /// a path containing primary key parameters. /// TRUE: GET (one), PUT, PATCH, DELETE /// FALSE: GET (Many), POST + /// Operations available based on permissions. /// Tags denoting how the operations should be categorized. /// Typically one tag value, the entity's REST path. /// Collection of operation types and associated definitions. @@ -326,67 +439,100 @@ private Dictionary CreateOperations( string entityName, SourceDefinition sourceDefinition, bool includePrimaryKeyPathComponent, + Dictionary configuredRestOperations, List tags) { Dictionary openApiPathItemOperations = new(); if (includePrimaryKeyPathComponent) { - // The OpenApiResponses dictionary key represents the integer value of the HttpStatusCode, - // which is returned when using Enum.ToString("D"). - // The "D" format specified "displays the enumeration entry as an integer value in the shortest representation possible." - // It will only contain $select query parameter to allow the user to specify which fields to return. 
- OpenApiOperation getOperation = CreateBaseOperation(description: GETONE_DESCRIPTION, tags: tags); - AddQueryParameters(getOperation.Parameters); - getOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); - openApiPathItemOperations.Add(OperationType.Get, getOperation); + if (configuredRestOperations[OperationType.Get]) + { + OpenApiOperation getOperation = CreateBaseOperation(description: GETONE_DESCRIPTION, tags: tags); + AddQueryParameters(getOperation.Parameters); + getOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); + openApiPathItemOperations.Add(OperationType.Get, getOperation); + } - // PUT and PATCH requests have the same criteria for decided whether a request body is required. - bool requestBodyRequired = IsRequestBodyRequired(sourceDefinition, considerPrimaryKeys: false); + // Only calculate requestBodyRequired if PUT or PATCH operations are configured + if (configuredRestOperations[OperationType.Put] || configuredRestOperations[OperationType.Patch]) + { + bool requestBodyRequired = IsRequestBodyRequired(sourceDefinition, considerPrimaryKeys: false); - // PUT requests must include the primary key(s) in the URI path and exclude from the request body, - // independent of whether the PK(s) are autogenerated. 
- OpenApiOperation putOperation = CreateBaseOperation(description: PUT_DESCRIPTION, tags: tags); - putOperation.RequestBody = CreateOpenApiRequestBodyPayload($"{entityName}_NoPK", requestBodyRequired); - putOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); - putOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); - openApiPathItemOperations.Add(OperationType.Put, putOperation); + if (configuredRestOperations[OperationType.Put]) + { + OpenApiOperation putOperation = CreateBaseOperation(description: PUT_DESCRIPTION, tags: tags); + putOperation.RequestBody = CreateOpenApiRequestBodyPayload($"{entityName}_NoPK", requestBodyRequired); + putOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); + putOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); + openApiPathItemOperations.Add(OperationType.Put, putOperation); + } - // PATCH requests must include the primary key(s) in the URI path and exclude from the request body, - // independent of whether the PK(s) are autogenerated. 
- OpenApiOperation patchOperation = CreateBaseOperation(description: PATCH_DESCRIPTION, tags: tags); - patchOperation.RequestBody = CreateOpenApiRequestBodyPayload($"{entityName}_NoPK", requestBodyRequired); - patchOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); - patchOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); - openApiPathItemOperations.Add(OperationType.Patch, patchOperation); + if (configuredRestOperations[OperationType.Patch]) + { + OpenApiOperation patchOperation = CreateBaseOperation(description: PATCH_DESCRIPTION, tags: tags); + patchOperation.RequestBody = CreateOpenApiRequestBodyPayload($"{entityName}_NoPK", requestBodyRequired); + patchOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); + patchOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); + openApiPathItemOperations.Add(OperationType.Patch, patchOperation); + } + } - OpenApiOperation deleteOperation = CreateBaseOperation(description: DELETE_DESCRIPTION, tags: tags); - deleteOperation.Responses.Add(HttpStatusCode.NoContent.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.NoContent))); - openApiPathItemOperations.Add(OperationType.Delete, deleteOperation); + if (configuredRestOperations[OperationType.Delete]) + { + OpenApiOperation deleteOperation = CreateBaseOperation(description: DELETE_DESCRIPTION, tags: tags); + deleteOperation.Responses.Add(HttpStatusCode.NoContent.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.NoContent))); + openApiPathItemOperations.Add(OperationType.Delete, deleteOperation); 
+ } return openApiPathItemOperations; } else { - // Primary key(s) are not included in the URI paths of the GET (all) and POST operations. - OpenApiOperation getAllOperation = CreateBaseOperation(description: GETALL_DESCRIPTION, tags: tags); - AddQueryParameters(getAllOperation.Parameters); - getAllOperation.Responses.Add( - HttpStatusCode.OK.ToString("D"), - CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName, includeNextLink: true)); - openApiPathItemOperations.Add(OperationType.Get, getAllOperation); - - // The POST body must include fields for primary key(s) which are not autogenerated because a value must be supplied - // for those fields. {entityName}_NoAutoPK represents the schema component which has all fields except for autogenerated primary keys. - // When no autogenerated primary keys exist, then all fields can be included in the POST body which is represented by the schema - // component: {entityName}. - string postBodySchemaReferenceId = DoesSourceContainAutogeneratedPrimaryKey(sourceDefinition) ? 
$"{entityName}_NoAutoPK" : $"{entityName}"; - - OpenApiOperation postOperation = CreateBaseOperation(description: POST_DESCRIPTION, tags: tags); - postOperation.RequestBody = CreateOpenApiRequestBodyPayload(postBodySchemaReferenceId, IsRequestBodyRequired(sourceDefinition, considerPrimaryKeys: true)); - postOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); - postOperation.Responses.Add(HttpStatusCode.Conflict.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Conflict))); - openApiPathItemOperations.Add(OperationType.Post, postOperation); + if (configuredRestOperations[OperationType.Get]) + { + OpenApiOperation getAllOperation = CreateBaseOperation(description: GETALL_DESCRIPTION, tags: tags); + AddQueryParameters(getAllOperation.Parameters); + getAllOperation.Responses.Add( + HttpStatusCode.OK.ToString("D"), + CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName, includeNextLink: true)); + openApiPathItemOperations.Add(OperationType.Get, getAllOperation); + } + + if (configuredRestOperations[OperationType.Post]) + { + string postBodySchemaReferenceId = DoesSourceContainAutogeneratedPrimaryKey(sourceDefinition) ? 
$"{entityName}_NoAutoPK" : $"{entityName}"; + OpenApiOperation postOperation = CreateBaseOperation(description: POST_DESCRIPTION, tags: tags); + postOperation.RequestBody = CreateOpenApiRequestBodyPayload(postBodySchemaReferenceId, IsRequestBodyRequired(sourceDefinition, considerPrimaryKeys: true)); + postOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); + postOperation.Responses.Add(HttpStatusCode.Conflict.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Conflict))); + openApiPathItemOperations.Add(OperationType.Post, postOperation); + } + + // For entities with auto-generated primary keys, add keyless PUT and PATCH operations. + // These routes allow creating records without specifying the primary key in the URL, + // which is useful for entities with identity/auto-generated keys. + if (DoesSourceContainAutogeneratedPrimaryKey(sourceDefinition)) + { + string keylessBodySchemaReferenceId = $"{entityName}_NoAutoPK"; + bool keylessRequestBodyRequired = IsRequestBodyRequired(sourceDefinition, considerPrimaryKeys: true); + + if (configuredRestOperations[OperationType.Put]) + { + OpenApiOperation putKeylessOperation = CreateBaseOperation(description: PUT_PATCH_KEYLESS_DESCRIPTION, tags: tags); + putKeylessOperation.RequestBody = CreateOpenApiRequestBodyPayload(keylessBodySchemaReferenceId, keylessRequestBodyRequired); + putKeylessOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); + openApiPathItemOperations.Add(OperationType.Put, putKeylessOperation); + } + + if (configuredRestOperations[OperationType.Patch]) + { + OpenApiOperation patchKeylessOperation = CreateBaseOperation(description: PUT_PATCH_KEYLESS_DESCRIPTION, tags: tags); + patchKeylessOperation.RequestBody = 
CreateOpenApiRequestBodyPayload(keylessBodySchemaReferenceId, keylessRequestBodyRequired); + patchKeylessOperation.Responses.Add(HttpStatusCode.Created.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.Created), responseObjectSchemaName: entityName)); + openApiPathItemOperations.Add(OperationType.Patch, patchKeylessOperation); + } + } return openApiPathItemOperations; } @@ -636,8 +782,9 @@ private static OpenApiParameter GetOpenApiQueryParameter(string name, string des /// /// The entity. /// Database object metadata, indicating entity SourceType + /// Optional role to filter permissions. If null, returns superset of all roles. /// Collection of OpenAPI OperationTypes and whether they should be created. - private static Dictionary GetConfiguredRestOperations(Entity entity, DatabaseObject dbObject) + private static Dictionary GetConfiguredRestOperations(Entity entity, DatabaseObject dbObject, string? role = null) { Dictionary configuredOperations = new() { @@ -691,16 +838,168 @@ private static Dictionary GetConfiguredRestOperations(Entit } else { - configuredOperations[OperationType.Get] = true; - configuredOperations[OperationType.Post] = true; - configuredOperations[OperationType.Put] = true; - configuredOperations[OperationType.Patch] = true; - configuredOperations[OperationType.Delete] = true; + // For tables/views, determine available operations from permissions + // If role is specified, filter to that role only; otherwise, get superset of all roles + // Note: PUT/PATCH require BOTH Create AND Update permissions (upsert semantics) + if (entity?.Permissions is not null) + { + bool hasCreate = false; + bool hasUpdate = false; + + foreach (EntityPermission permission in entity.Permissions) + { + // Skip permissions for other roles if a specific role is requested + if (role is not null && !string.Equals(permission.Role, role, StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + if (permission.Actions is null) + { + continue; + } + + 
foreach (EntityAction action in permission.Actions) + { + if (action.Action == EntityActionOperation.All) + { + configuredOperations[OperationType.Get] = true; + configuredOperations[OperationType.Post] = true; + configuredOperations[OperationType.Delete] = true; + hasCreate = true; + hasUpdate = true; + } + else + { + switch (action.Action) + { + case EntityActionOperation.Read: + configuredOperations[OperationType.Get] = true; + break; + case EntityActionOperation.Create: + configuredOperations[OperationType.Post] = true; + hasCreate = true; + break; + case EntityActionOperation.Update: + hasUpdate = true; + break; + case EntityActionOperation.Delete: + configuredOperations[OperationType.Delete] = true; + break; + } + } + } + } + + // PUT/PATCH require both Create and Update permissions (upsert semantics) + if (hasCreate && hasUpdate) + { + configuredOperations[OperationType.Put] = true; + configuredOperations[OperationType.Patch] = true; + } + } } return configuredOperations; } + /// + /// Checks if an entity has any available REST operations based on its permissions. + /// + /// The entity to check. + /// Optional role to filter permissions. If null, checks all roles. + /// True if the entity has any available operations. + private static bool HasAnyAvailableOperations(Entity entity, string? role = null) + { + if (entity?.Permissions is null || entity.Permissions.Length == 0) + { + return false; + } + + foreach (EntityPermission permission in entity.Permissions) + { + // Skip permissions for other roles if a specific role is requested + if (role is not null && !string.Equals(permission.Role, role, StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + if (permission.Actions?.Length > 0) + { + return true; + } + } + + return false; + } + + /// + /// Filters the exposed column names based on the superset of available fields across role permissions. + /// A field is included if at least one role (or the specified role) has access to it. 
+ /// + /// The entity to check permissions for. + /// All exposed column names from the database. + /// Optional role to filter permissions. If null, returns superset of all roles. + /// Filtered set of column names that are available based on permissions. + private static HashSet FilterFieldsByPermissions(Entity entity, HashSet exposedColumnNames, string? role = null) + { + if (entity?.Permissions is null || entity.Permissions.Length == 0) + { + return exposedColumnNames; + } + + HashSet availableFields = new(); + + foreach (EntityPermission permission in entity.Permissions) + { + // Skip permissions for other roles if a specific role is requested + if (role is not null && !string.Equals(permission.Role, role, StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + // If actions is not defined for a matching role, all fields are available + if (permission.Actions is null) + { + return exposedColumnNames; + } + + foreach (EntityAction action in permission.Actions) + { + // If Fields is null, all fields are available for this action + if (action.Fields is null) + { + availableFields.UnionWith(exposedColumnNames); + continue; + } + + // Determine included fields using ternary - either all fields or explicitly listed + HashSet actionFields = (action.Fields.Include is null || action.Fields.Include.Contains("*")) + ? 
new HashSet(exposedColumnNames) + : new HashSet(action.Fields.Include.Where(f => exposedColumnNames.Contains(f))); + + // Remove excluded fields + if (action.Fields.Exclude is not null && action.Fields.Exclude.Count > 0) + { + if (action.Fields.Exclude.Contains("*")) + { + // Exclude all - no fields available for this action + actionFields.Clear(); + } + else + { + actionFields.ExceptWith(action.Fields.Exclude); + } + } + + // Add to superset of available fields + availableFields.UnionWith(actionFields); + } + } + + return availableFields; + } + /// /// Creates the request body definition, which includes the expected media type (application/json) /// and reference to request body schema. @@ -988,8 +1287,10 @@ private static OpenApiMediaType CreateResponseContainer(string responseObjectSch /// 3) {EntityName}_NoPK -> No primary keys present in schema, used for POST requests where PK is autogenerated and GET (all). /// Schema objects can be referenced elsewhere in the OpenAPI document with the intent to reduce document verbosity. /// + /// Optional role to filter permissions. If null, returns superset of all roles. + /// When true, request body schemas disallow extra fields. /// Collection of schemas for entities defined in the runtime configuration. - private Dictionary CreateComponentSchemas(RuntimeEntities entities, string defaultDataSourceName) + private Dictionary CreateComponentSchemas(RuntimeEntities entities, string defaultDataSourceName, string? role = null, bool isRequestBodyStrict = true) { Dictionary schemas = new(); // for rest scenario we need the default datasource name. @@ -1002,69 +1303,90 @@ private Dictionary CreateComponentSchemas(RuntimeEntities string entityName = entityDbMetadataMap.Key; DatabaseObject dbObject = entityDbMetadataMap.Value; - if (!entities.TryGetValue(entityName, out Entity? entity) || !entity.Rest.Enabled) + if (!entities.TryGetValue(entityName, out Entity? 
entity) || !entity.Rest.Enabled || !HasAnyAvailableOperations(entity, role)) { // Don't create component schemas for: // 1. Linking entity: The entity will be null when we are dealing with a linking entity, which is not exposed in the config. // 2. Entity for which REST endpoint is disabled. + // 3. Entity with no available operations based on permissions. continue; } SourceDefinition sourceDefinition = metadataProvider.GetSourceDefinition(entityName); HashSet exposedColumnNames = GetExposedColumnNames(entityName, sourceDefinition.Columns.Keys.ToList(), metadataProvider); + + // Filter fields based on the superset of permissions across all roles (or specific role) + exposedColumnNames = FilterFieldsByPermissions(entity, exposedColumnNames, role); + + // Get configured operations to determine which schemas to generate + Dictionary configuredOps = GetConfiguredRestOperations(entity, dbObject, role); + bool hasPostOperation = configuredOps.GetValueOrDefault(OperationType.Post); + bool hasPutPatchOperation = configuredOps.GetValueOrDefault(OperationType.Put) || configuredOps.GetValueOrDefault(OperationType.Patch); + HashSet nonAutoGeneratedPKColumnNames = new(); if (dbObject.SourceType is EntitySourceType.StoredProcedure) { - // Request body schema whose properties map to stored procedure parameters - DatabaseStoredProcedure spObject = (DatabaseStoredProcedure)dbObject; - schemas.Add(entityName + SP_REQUEST_SUFFIX, CreateSpRequestComponentSchema(fields: spObject.StoredProcedureDefinition.Parameters)); + // Only generate request body schema if SP has operations that use it + if (hasPostOperation || hasPutPatchOperation) + { + DatabaseStoredProcedure spObject = (DatabaseStoredProcedure)dbObject; + schemas.Add(entityName + SP_REQUEST_SUFFIX, CreateSpRequestComponentSchema(fields: spObject.StoredProcedureDefinition.Parameters, isRequestBodyStrict: isRequestBodyStrict)); + } // Response body schema whose properties map to the stored procedure's first result set columns // 
as described by sys.dm_exec_describe_first_result_set. - schemas.Add(entityName + SP_RESPONSE_SUFFIX, CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities)); + // Response schemas don't need additionalProperties restriction + schemas.Add(entityName + SP_RESPONSE_SUFFIX, CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities, isRequestBodySchema: false)); } else { // Create component schema for FULL entity with all primary key columns (included auto-generated) // which will typically represent the response body of a request or a stored procedure's request body. - schemas.Add(entityName, CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities)); + // Response schemas don't need additionalProperties restriction + schemas.Add(entityName, CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities, isRequestBodySchema: false)); - // Create an entity's request body component schema excluding autogenerated primary keys. - // A POST request requires any non-autogenerated primary key references to be in the request body. - foreach (string primaryKeyColumn in sourceDefinition.PrimaryKey) + // Only generate request body schemas if mutation operations are available + if (hasPostOperation || hasPutPatchOperation) { - // Non-Autogenerated primary key(s) should appear in the request body. - if (!sourceDefinition.Columns[primaryKeyColumn].IsAutoGenerated) + // Create an entity's request body component schema excluding autogenerated primary keys. + // A POST request requires any non-autogenerated primary key references to be in the request body. + foreach (string primaryKeyColumn in sourceDefinition.PrimaryKey) { - nonAutoGeneratedPKColumnNames.Add(primaryKeyColumn); - continue; - } + // Non-Autogenerated primary key(s) should appear in the request body. 
+ if (!sourceDefinition.Columns[primaryKeyColumn].IsAutoGenerated) + { + nonAutoGeneratedPKColumnNames.Add(primaryKeyColumn); + continue; + } - if (metadataProvider.TryGetExposedColumnName(entityName, backingFieldName: primaryKeyColumn, out string? exposedColumnName) - && exposedColumnName is not null) - { - exposedColumnNames.Remove(exposedColumnName); + if (metadataProvider.TryGetExposedColumnName(entityName, backingFieldName: primaryKeyColumn, out string? exposedColumnName) + && exposedColumnName is not null) + { + exposedColumnNames.Remove(exposedColumnName); + } } - } - schemas.Add($"{entityName}_NoAutoPK", CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities)); + // Request body schema for POST - apply additionalProperties based on strict mode + schemas.Add($"{entityName}_NoAutoPK", CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities, isRequestBodySchema: true, isRequestBodyStrict: isRequestBodyStrict)); - // Create an entity's request body component schema excluding all primary keys - // by removing the tracked non-autogenerated primary key column names and removing them from - // the exposedColumnNames collection. - // The schema component without primary keys is used for PUT and PATCH operation request bodies because - // those operations require all primary key references to be in the URI path, not the request body. - foreach (string primaryKeyColumn in nonAutoGeneratedPKColumnNames) - { - if (metadataProvider.TryGetExposedColumnName(entityName, backingFieldName: primaryKeyColumn, out string? exposedColumnName) - && exposedColumnName is not null) + // Create an entity's request body component schema excluding all primary keys + // by removing the tracked non-autogenerated primary key column names and removing them from + // the exposedColumnNames collection. 
+ // The schema component without primary keys is used for PUT and PATCH operation request bodies because + // those operations require all primary key references to be in the URI path, not the request body. + foreach (string primaryKeyColumn in nonAutoGeneratedPKColumnNames) { - exposedColumnNames.Remove(exposedColumnName); + if (metadataProvider.TryGetExposedColumnName(entityName, backingFieldName: primaryKeyColumn, out string? exposedColumnName) + && exposedColumnName is not null) + { + exposedColumnNames.Remove(exposedColumnName); + } } - } - schemas.Add($"{entityName}_NoPK", CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities)); + // Request body schema for PUT/PATCH - apply additionalProperties based on strict mode + schemas.Add($"{entityName}_NoPK", CreateComponentSchema(entityName, fields: exposedColumnNames, metadataProvider, entities, isRequestBodySchema: true, isRequestBodyStrict: isRequestBodyStrict)); + } } } @@ -1077,10 +1399,10 @@ private Dictionary CreateComponentSchemas(RuntimeEntities /// Additionally, the property typeMetadata is sourced by converting the stored procedure /// parameter's SystemType to JsonDataType. /// - /// /// Collection of stored procedure parameter metadata. + /// When true, sets additionalProperties to false. /// OpenApiSchema object representing a stored procedure's request body. - private static OpenApiSchema CreateSpRequestComponentSchema(Dictionary fields) + private static OpenApiSchema CreateSpRequestComponentSchema(Dictionary fields, bool isRequestBodyStrict = true) { Dictionary properties = new(); HashSet required = new(); @@ -1108,7 +1430,9 @@ private static OpenApiSchema CreateSpRequestComponentSchema(DictionaryList of mapped (alias) field names. /// Metadata provider for database objects. /// Runtime entities from configuration. + /// Whether this schema is for a request body (applies additionalProperties setting). 
+ /// When true and isRequestBodySchema, sets additionalProperties to false. /// Raised when an entity's database metadata can't be found, /// indicating a failure due to the provided entityName. /// Entity's OpenApiSchema representation. - private static OpenApiSchema CreateComponentSchema(string entityName, HashSet fields, ISqlMetadataProvider metadataProvider, RuntimeEntities entities) + private static OpenApiSchema CreateComponentSchema( + string entityName, + HashSet fields, + ISqlMetadataProvider metadataProvider, + RuntimeEntities entities, + bool isRequestBodySchema = false, + bool isRequestBodyStrict = true) { if (!metadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { @@ -1177,7 +1509,10 @@ private static OpenApiSchema CreateComponentSchema(string entityName, HashSet /// Upsert Request context containing the request body. + /// When true the primary key was provided in the URL route + /// and PK columns in the body are skipped (original behaviour). When false the primary key + /// is expected in the request body, so non-auto-generated PK columns must be present and + /// the full composite key (if applicable) must be supplied. /// - public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx) + public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx, bool isPrimaryKeyInUrl = true) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(upsertRequestCtx.EntityName); IEnumerable fieldsInRequestBody = upsertRequestCtx.FieldValuePairsInBody.Keys; @@ -385,13 +389,45 @@ public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx) unValidatedFields.Remove(exposedName!); } - // Primary Key(s) should not be present in the request body. We do not fail a request - // if a PK is autogenerated here, because an UPSERT request may only need to update a - // record. 
If an insert occurs on a table with autogenerated primary key, - // a database error will be returned. + // When the primary key is provided in the URL route, skip PK columns in body validation. + // When the primary key is NOT in the URL (body-based PK), we need to validate that + // all non-auto-generated PK columns are present in the body to form a complete key. if (sourceDefinition.PrimaryKey.Contains(column.Key)) { - continue; + if (isPrimaryKeyInUrl) + { + continue; + } + else + { + // Body-based PK: non-auto-generated PK columns MUST be present. + // Auto-generated PK columns are skipped — they cannot be supplied by the caller. + if (column.Value.IsAutoGenerated) + { + continue; + } + + if (!fieldsInRequestBody.Contains(exposedName)) + { + throw new DataApiBuilderException( + message: $"Invalid request body. Missing field in body: {exposedName}.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + // PK value must not be null for non-nullable PK columns. + if (!column.Value.IsNullable && + upsertRequestCtx.FieldValuePairsInBody[exposedName!] is null) + { + throw new DataApiBuilderException( + message: $"Invalid value for field {exposedName} in request body.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + unValidatedFields.Remove(exposedName!); + continue; + } } // Request body must have value defined for included non-nullable columns @@ -488,7 +524,6 @@ public void ValidateEntity(string entityName) /// Tries to get the table definition for the given entity from the Metadata provider. /// /// Target entity name. - /// enables referencing DB schema. 
/// private static SourceDefinition TryGetSourceDefinition(string entityName, ISqlMetadataProvider sqlMetadataProvider) diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 6a2308dd83..2bfab7e05f 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -70,24 +70,25 @@ RequestValidator requestValidator ISqlMetadataProvider sqlMetadataProvider = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName); DatabaseObject dbObject = sqlMetadataProvider.EntityToDatabaseObject[entityName]; - if (dbObject.SourceType is not EntitySourceType.StoredProcedure) - { - await AuthorizationCheckForRequirementAsync(resource: entityName, requirement: new EntityRoleOperationPermissionsRequirement()); - } - else - { - await AuthorizationCheckForRequirementAsync(resource: entityName, requirement: new StoredProcedurePermissionsRequirement()); - } - QueryString? query = GetHttpContext().Request.QueryString; string queryString = query is null ? string.Empty : GetHttpContext().Request.QueryString.ToString(); + // Read the request body early so it can be used for downstream processing. 
string requestBody = string.Empty; using (StreamReader reader = new(GetHttpContext().Request.Body)) { requestBody = await reader.ReadToEndAsync(); } + if (dbObject.SourceType is not EntitySourceType.StoredProcedure) + { + await AuthorizationCheckForRequirementAsync(resource: entityName, requirement: new EntityRoleOperationPermissionsRequirement()); + } + else + { + await AuthorizationCheckForRequirementAsync(resource: entityName, requirement: new StoredProcedurePermissionsRequirement()); + } + RestRequestContext context; // If request has resolved to a stored procedure entity, initialize and validate appropriate request context @@ -144,7 +145,21 @@ RequestValidator requestValidator case EntityActionOperation.UpdateIncremental: case EntityActionOperation.Upsert: case EntityActionOperation.UpsertIncremental: - RequestValidator.ValidatePrimaryKeyRouteAndQueryStringInURL(operationType, primaryKeyRoute); + // For Upsert/UpsertIncremental, a keyless URL is allowed. When the + // primary key route is absent, ValidateUpsertRequestContext checks that + // the body contains all non-auto-generated PK columns so the mutation + // engine can resolve the target row (or insert a new one). + // Update/UpdateIncremental always require the PK in the URL. 
+ if (!string.IsNullOrEmpty(primaryKeyRoute)) + { + RequestValidator.ValidatePrimaryKeyRouteAndQueryStringInURL(operationType, primaryKeyRoute); + } + else if (operationType is not EntityActionOperation.Upsert and + not EntityActionOperation.UpsertIncremental) + { + RequestValidator.ValidatePrimaryKeyRouteAndQueryStringInURL(operationType, primaryKeyRoute); + } + JsonElement upsertPayloadRoot = RequestValidator.ValidateAndParseRequestBody(requestBody); context = new UpsertRequestContext( entityName, @@ -153,7 +168,9 @@ RequestValidator requestValidator operationType); if (context.DatabaseObject.SourceType is EntitySourceType.Table) { - _requestValidator.ValidateUpsertRequestContext((UpsertRequestContext)context); + _requestValidator.ValidateUpsertRequestContext( + (UpsertRequestContext)context, + isPrimaryKeyInUrl: !string.IsNullOrEmpty(primaryKeyRoute)); } break; @@ -174,6 +191,7 @@ RequestValidator requestValidator if (!string.IsNullOrWhiteSpace(queryString)) { + context.RawQueryString = queryString; context.ParsedQueryString = HttpUtility.ParseQueryString(queryString); RequestParser.ParseQueryString(context, sqlMetadataProvider); } @@ -277,6 +295,7 @@ private void PopulateStoredProcedureContext( // So, $filter will be treated as any other parameter (inevitably will raise a Bad Request) if (!string.IsNullOrWhiteSpace(queryString)) { + context.RawQueryString = queryString; context.ParsedQueryString = HttpUtility.ParseQueryString(queryString); } @@ -433,11 +452,17 @@ public bool TryGetRestRouteFromConfig([NotNullWhen(true)] out string? configured /// /// Tries to get the Entity name and primary key route from the provided string - /// returns the entity name via a lookup using the string which includes - /// characters up until the first '/', and then resolves the primary key - /// as the substring following the '/'. 
+ /// by matching against configured entity paths (which may include '/' for sub-directories) + /// using longest-prefix matching, then treating the remaining suffix as the primary key route. + /// /// For example, a request route should be of the form /// {EntityPath}/{PKColumn}/{PkValue}/{PKColumn}/{PKValue}... + /// where {EntityPath} may be a single segment like "books" or multi-segment like "shopping-cart/item". + /// + /// Uses longest-prefix matching (most-specific match wins). When multiple + /// entity paths could match, the longest matching path takes precedence. For example, + /// if both "cart" and "cart/item" are valid entity paths, a request to + /// "cart/item/id/123" will match "cart/item" with primaryKeyRoute "id/123". /// /// The request route (no '/' prefix) containing the entity path /// (and optionally primary key). @@ -448,26 +473,27 @@ public bool TryGetRestRouteFromConfig([NotNullWhen(true)] out string? configured RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); - // Split routeAfterPath on the first occurrence of '/', if we get back 2 elements - // this means we have a non-empty primary key route which we save. Otherwise, save - // primary key route as empty string. Entity Path will always be the element at index 0. - // ie: {EntityPath}/{PKColumn}/{PkValue}/{PKColumn}/{PKValue}... - // splits into [{EntityPath}] when there is an empty primary key route and into - // [{EntityPath}, {Primarykeyroute}] when there is a non-empty primary key route. - int maxNumberOfElementsFromSplit = 2; - string[] entityPathAndPKRoute = routeAfterPathBase.Split(new[] { '/' }, maxNumberOfElementsFromSplit); - string entityPath = entityPathAndPKRoute[0]; - string primaryKeyRoute = entityPathAndPKRoute.Length == maxNumberOfElementsFromSplit ? entityPathAndPKRoute[1] : string.Empty; - - if (!runtimeConfig.TryGetEntityNameFromPath(entityPath, out string? 
entityName)) + // Split routeAfterPath to extract segments + string[] segments = routeAfterPathBase.Split('/'); + + // Try longest paths first (most-specific match wins) + // Start with all segments, then remove one at a time + for (int i = segments.Length; i >= 1; i--) { - throw new DataApiBuilderException( - message: $"Invalid Entity path: {entityPath}.", - statusCode: HttpStatusCode.NotFound, - subStatusCode: DataApiBuilderException.SubStatusCodes.EntityNotFound); + string entityPath = string.Join("/", segments.Take(i)); + if (runtimeConfig.TryGetEntityNameFromPath(entityPath, out string? entityName)) + { + // Found entity + string primaryKeyRoute = i < segments.Length ? string.Join("/", segments.Skip(i)) : string.Empty; + return (entityName!, primaryKeyRoute); + } } - return (entityName!, primaryKeyRoute); + // No entity found - show the full path for better debugging + throw new DataApiBuilderException( + message: $"Invalid Entity path: {routeAfterPathBase}.", + statusCode: HttpStatusCode.NotFound, + subStatusCode: DataApiBuilderException.SubStatusCodes.EntityNotFound); } /// diff --git a/src/Core/Telemetry/TelemetryTracesHelper.cs b/src/Core/Telemetry/TelemetryTracesHelper.cs index 01c5acbf51..152a2e6e68 100644 --- a/src/Core/Telemetry/TelemetryTracesHelper.cs +++ b/src/Core/Telemetry/TelemetryTracesHelper.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Net; using Azure.DataApiBuilder.Config.ObjectModel; -using OpenTelemetry.Trace; using Kestral = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod; namespace Azure.DataApiBuilder.Core.Telemetry @@ -104,12 +103,85 @@ public static void TrackMainControllerActivityFinishedWithException( { if (activity.IsAllDataRequested) { - activity.SetStatus(Status.Error.WithDescription(ex.Message)); - activity.RecordException(ex); + activity.SetStatus(ActivityStatusCode.Error, ex.Message); + activity.AddException(ex); activity.SetTag("error.type", ex.GetType().Name); activity.SetTag("error.message", 
ex.Message); activity.SetTag("status.code", statusCode); } } + + /// + /// Tracks the start of an MCP tool execution activity. + /// + /// The activity instance. + /// The name of the MCP tool being executed. + /// The entity name associated with the tool (optional). + /// The operation being performed (e.g., execute, read, create). + /// The database procedure being executed (optional, schema-qualified if available). + public static void TrackMcpToolExecutionStarted( + this Activity activity, + string toolName, + string? entityName = null, + string? operation = null, + string? dbProcedure = null) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("mcp.tool.name", toolName); + + if (!string.IsNullOrEmpty(entityName)) + { + activity.SetTag("dab.entity", entityName); + } + + if (!string.IsNullOrEmpty(operation)) + { + activity.SetTag("dab.operation", operation); + } + + if (!string.IsNullOrEmpty(dbProcedure)) + { + activity.SetTag("db.procedure", dbProcedure); + } + } + } + + /// + /// Tracks the successful completion of an MCP tool execution. + /// + /// The activity instance. + public static void TrackMcpToolExecutionFinished(this Activity activity) + { + if (activity.IsAllDataRequested) + { + activity.SetStatus(ActivityStatusCode.Ok); + } + } + + /// + /// Tracks the completion of an MCP tool execution with an exception. + /// + /// The activity instance. + /// The exception that occurred. + /// Optional error code for the failure. + public static void TrackMcpToolExecutionFinishedWithException( + this Activity activity, + Exception ex, + string? 
errorCode = null) + { + if (activity.IsAllDataRequested) + { + activity.SetStatus(ActivityStatusCode.Error, ex.Message); + activity.AddException(ex); + activity.SetTag("error.type", ex.GetType().Name); + activity.SetTag("error.message", ex.Message); + + if (!string.IsNullOrEmpty(errorCode)) + { + activity.SetTag("error.code", errorCode); + } + } + } } } diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 26ad392ae8..c09275c61b 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -2,7 +2,7 @@ enable ..\out - 1.7 + 2.0 diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index ccd69b9600..dfd605cc8f 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -7,7 +7,7 @@ - + @@ -34,19 +34,20 @@ + - - - - - - - + + + + + + + - + diff --git a/src/Service.Tests/Authentication/Helpers/RuntimeConfigAuthHelper.cs b/src/Service.Tests/Authentication/Helpers/RuntimeConfigAuthHelper.cs index 07a8a565ec..db3ae28de4 100644 --- a/src/Service.Tests/Authentication/Helpers/RuntimeConfigAuthHelper.cs +++ b/src/Service.Tests/Authentication/Helpers/RuntimeConfigAuthHelper.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+#nullable enable + using System.Collections.Generic; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.ObjectModel; @@ -27,4 +29,29 @@ internal static RuntimeConfig CreateTestConfigWithAuthNProvider(AuthenticationOp ); return config; } + + internal static RuntimeConfig CreateTestConfigWithAuthNProviderAndUserDelegatedAuth( + AuthenticationOptions authenticationOptions, + UserDelegatedAuthOptions userDelegatedAuthOptions) + { + DataSource dataSource = new DataSource(DatabaseType.MSSQL, "", new Dictionary()) with + { + UserDelegatedAuth = userDelegatedAuthOptions + }; + + HostOptions hostOptions = new(Cors: null, Authentication: authenticationOptions); + RuntimeConfig config = new( + Schema: FileSystemRuntimeConfigLoader.SCHEMA, + DataSource: dataSource, + Runtime: new RuntimeOptions( + Rest: new RestRuntimeOptions(), + GraphQL: new GraphQLRuntimeOptions(), + Mcp: new McpRuntimeOptions(), + Host: hostOptions + ), + Entities: new(new Dictionary()) + ); + + return config; + } } diff --git a/src/Service.Tests/Authentication/JwtTokenAuthenticationUnitTests.cs b/src/Service.Tests/Authentication/JwtTokenAuthenticationUnitTests.cs index a805c3ab1a..789770c985 100644 --- a/src/Service.Tests/Authentication/JwtTokenAuthenticationUnitTests.cs +++ b/src/Service.Tests/Authentication/JwtTokenAuthenticationUnitTests.cs @@ -173,7 +173,15 @@ public async Task TestInvalidToken_BadAudience() Assert.AreEqual(expected: (int)HttpStatusCode.Unauthorized, actual: postMiddlewareContext.Response.StatusCode); Assert.IsFalse(postMiddlewareContext.User.Identity.IsAuthenticated); StringValues headerValue = GetChallengeHeader(postMiddlewareContext); - Assert.IsTrue(headerValue[0].Contains("invalid_token") && headerValue[0].Contains($"The audience '{BAD_AUDIENCE}' is invalid")); + + // Microsoft.IdentityModel.Tokens version 8.8+ scrubs the Audience from the error message + // This behavior can be disabled with 
AppContext.SetSwitch("Switch.Microsoft.IdentityModel.DoNotScrubExceptions", true); + // See https://aka.ms/identitymodel/app-context-switches + string expectedAudienceInErrorMessage = AppContext.TryGetSwitch("Switch.Microsoft.IdentityModel.DoNotScrubExceptions", out bool isExceptionScrubbingDisabled) && isExceptionScrubbingDisabled + ? BAD_AUDIENCE + : "(null)"; + + Assert.IsTrue(headerValue[0].Contains("invalid_token") && headerValue[0].Contains($"The audience '{expectedAudienceInErrorMessage}' is invalid")); } /// diff --git a/src/Service.Tests/Authentication/OboSqlTokenProviderUnitTests.cs b/src/Service.Tests/Authentication/OboSqlTokenProviderUnitTests.cs new file mode 100644 index 0000000000..dee163a36e --- /dev/null +++ b/src/Service.Tests/Authentication/OboSqlTokenProviderUnitTests.cs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Extensions.Logging; +using Microsoft.Identity.Client; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using ZiggyCreatures.Caching.Fusion; +using AuthenticationOptions = Azure.DataApiBuilder.Config.ObjectModel.AuthenticationOptions; + +namespace Azure.DataApiBuilder.Service.Tests.Authentication +{ + /// + /// Unit tests for which handles On-Behalf-Of (OBO) token + /// acquisition for delegated user authentication to Azure SQL Database. + /// Tests cover: input validation, claim extraction, token caching, scope formatting, and error handling. 
+ /// + [TestClass] + public class OboSqlTokenProviderUnitTests + { + private const string TEST_DATABASE_AUDIENCE = "https://database.windows.net/"; + private const string TEST_SUBJECT_OID = "00000000-0000-0000-0000-000000000001"; + private const string TEST_SUBJECT_SUB = "00000000-0000-0000-0000-000000000002"; + private const string TEST_TENANT_ID = "11111111-1111-1111-1111-111111111111"; + private const string TEST_ACCESS_TOKEN = "mock-sql-access-token"; + private const string TEST_INCOMING_JWT = "incoming.jwt.assertion"; + + private Mock _msalMock; + private Mock> _loggerMock; + private IFusionCache _cache; + private OboSqlTokenProvider _provider; + + /// + /// Initializes mocks and provider before each test. + /// + [TestInitialize] + public void TestInit() + { + _msalMock = new Mock(); + _loggerMock = new Mock>(); + _cache = CreateFusionCache(); + _provider = new OboSqlTokenProvider(_msalMock.Object, _loggerMock.Object, _cache); + } + + /// + /// Cleanup FusionCache after each test. + /// + [TestCleanup] + public void TestCleanup() + { + (_cache as IDisposable)?.Dispose(); + } + + #region Input Validation Tests + + /// + /// Verifies that null/empty inputs return null without calling MSAL. + /// + [DataTestMethod] + [DataRow(null, TEST_INCOMING_JWT, DisplayName = "Null principal")] + [DataRow("valid", "", DisplayName = "Empty JWT assertion")] + [DataRow("valid", null, DisplayName = "Null JWT assertion")] + public async Task GetAccessTokenOnBehalfOfAsync_InvalidInput_ReturnsNull( + string principalMarker, string jwtAssertion) + { + // Arrange + ClaimsPrincipal principal = principalMarker == null + ? 
null + : CreatePrincipalWithOid(TEST_SUBJECT_OID, TEST_TENANT_ID); + + // Act + string result = await _provider.GetAccessTokenOnBehalfOfAsync( + principal: principal, + incomingJwtAssertion: jwtAssertion, + databaseAudience: TEST_DATABASE_AUDIENCE); + + // Assert + Assert.IsNull(result); + _msalMock.Verify( + m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never, + "MSAL should not be called for invalid input."); + } + + /// + /// Verifies that missing required claims (oid/sub or tid) throw DataApiBuilderException. + /// + [DataTestMethod] + [DataRow(null, null, TEST_TENANT_ID, "OBO_IDENTITY_CLAIMS_MISSING", DisplayName = "Missing oid and sub")] + [DataRow(TEST_SUBJECT_OID, null, null, "OBO_TENANT_CLAIM_MISSING", DisplayName = "Missing tenant id")] + public async Task GetAccessTokenOnBehalfOfAsync_MissingRequiredClaims_ThrowsUnauthorized( + string oid, string sub, string tid, string expectedErrorConstant) + { + // Arrange + ClaimsIdentity identity = new(); + if (oid != null) + { + identity.AddClaim(new Claim("oid", oid)); + } + + if (sub != null) + { + identity.AddClaim(new Claim("sub", sub)); + } + + if (tid != null) + { + identity.AddClaim(new Claim("tid", tid)); + } + + ClaimsPrincipal principal = new(identity); + + // Act & Assert + DataApiBuilderException ex = await Assert.ThrowsExceptionAsync( + async () => await _provider.GetAccessTokenOnBehalfOfAsync( + principal: principal, + incomingJwtAssertion: TEST_INCOMING_JWT, + databaseAudience: TEST_DATABASE_AUDIENCE)); + + Assert.AreEqual(HttpStatusCode.Unauthorized, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure, ex.SubStatusCode); + + // Verify the correct error message constant is used + string expectedMessage = expectedErrorConstant == "OBO_IDENTITY_CLAIMS_MISSING" + ? 
DataApiBuilderException.OBO_IDENTITY_CLAIMS_MISSING + : DataApiBuilderException.OBO_TENANT_CLAIM_MISSING; + Assert.AreEqual(expectedMessage, ex.Message); + } + + #endregion + + #region Claim Extraction Tests + + /// + /// Verifies that 'oid' claim is preferred over 'sub' when both are present. + /// + [TestMethod] + public async Task GetAccessTokenOnBehalfOfAsync_PrefersOidOverSub() + { + // Arrange + SetupMsalSuccess(); + + ClaimsIdentity identity = new(); + identity.AddClaim(new Claim("oid", TEST_SUBJECT_OID)); + identity.AddClaim(new Claim("sub", TEST_SUBJECT_SUB)); + identity.AddClaim(new Claim("tid", TEST_TENANT_ID)); + ClaimsPrincipal principal = new(identity); + + // Act + string result = await _provider.GetAccessTokenOnBehalfOfAsync( + principal: principal, + incomingJwtAssertion: TEST_INCOMING_JWT, + databaseAudience: TEST_DATABASE_AUDIENCE); + + // Assert + Assert.IsNotNull(result); + Assert.AreEqual(TEST_ACCESS_TOKEN, result); + } + + /// + /// Verifies that 'sub' claim is used when 'oid' is not present. + /// + [TestMethod] + public async Task GetAccessTokenOnBehalfOfAsync_FallsBackToSub_WhenOidMissing() + { + // Arrange + SetupMsalSuccess(); + + ClaimsIdentity identity = new(); + identity.AddClaim(new Claim("sub", TEST_SUBJECT_SUB)); + identity.AddClaim(new Claim("tid", TEST_TENANT_ID)); + ClaimsPrincipal principal = new(identity); + + // Act + string result = await _provider.GetAccessTokenOnBehalfOfAsync( + principal: principal, + incomingJwtAssertion: TEST_INCOMING_JWT, + databaseAudience: TEST_DATABASE_AUDIENCE); + + // Assert + Assert.IsNotNull(result); + Assert.AreEqual(TEST_ACCESS_TOKEN, result); + } + + #endregion + + #region Token Caching Tests + + /// + /// Verifies that tokens are cached and reused for identical requests. + /// FusionCache handles caching - factory should only be called once. 
+ /// + [TestMethod] + public async Task GetAccessTokenOnBehalfOfAsync_CachesToken_AndReturnsCachedOnSecondCall() + { + // Arrange + int msalCallCount = 0; + _msalMock + .Setup(m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback(() => msalCallCount++) + .ReturnsAsync(CreateAuthenticationResult(TEST_ACCESS_TOKEN, DateTimeOffset.UtcNow.AddMinutes(30))); + + ClaimsPrincipal principal = CreatePrincipalWithOid(TEST_SUBJECT_OID, TEST_TENANT_ID); + + // Act + string result1 = await _provider.GetAccessTokenOnBehalfOfAsync(principal, TEST_INCOMING_JWT, TEST_DATABASE_AUDIENCE); + string result2 = await _provider.GetAccessTokenOnBehalfOfAsync(principal, TEST_INCOMING_JWT, TEST_DATABASE_AUDIENCE); + + // Assert + Assert.IsNotNull(result1); + Assert.AreEqual(result1, result2); + Assert.AreEqual(1, msalCallCount, "MSAL should only be called once due to FusionCache caching."); + } + + /// + /// Verifies that different roles produce different cache keys. + /// + [TestMethod] + public async Task GetAccessTokenOnBehalfOfAsync_DifferentRoles_ProducesDifferentCacheKeys() + { + // Arrange + int msalCallCount = 0; + _msalMock + .Setup(m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback(() => msalCallCount++) + .ReturnsAsync(() => CreateAuthenticationResult($"token-{msalCallCount}", DateTimeOffset.UtcNow.AddMinutes(30))); + + ClaimsPrincipal principalReader = CreatePrincipalWithRoles(TEST_SUBJECT_OID, TEST_TENANT_ID, "reader"); + ClaimsPrincipal principalWriter = CreatePrincipalWithRoles(TEST_SUBJECT_OID, TEST_TENANT_ID, "writer"); + + // Act + string resultReader = await _provider.GetAccessTokenOnBehalfOfAsync(principalReader, TEST_INCOMING_JWT, TEST_DATABASE_AUDIENCE); + string resultWriter = await _provider.GetAccessTokenOnBehalfOfAsync(principalWriter, TEST_INCOMING_JWT, TEST_DATABASE_AUDIENCE); + + // Assert + Assert.IsNotNull(resultReader); + Assert.IsNotNull(resultWriter); + Assert.AreNotEqual(resultReader, 
resultWriter, "Different roles should produce different tokens."); + Assert.AreEqual(2, msalCallCount, "Different roles should produce different cache keys, requiring two MSAL calls."); + } + + #endregion + + #region Scope Formatting Tests + + /// + /// Verifies that scope is correctly formatted from audience (with or without trailing slash). + /// + [DataTestMethod] + [DataRow("https://database.windows.net/", "https://database.windows.net/.default", DisplayName = "With trailing slash")] + [DataRow("https://database.windows.net", "https://database.windows.net/.default", DisplayName = "Without trailing slash")] + public async Task GetAccessTokenOnBehalfOfAsync_FormatsScope_CorrectlyFromAudience( + string databaseAudience, string expectedScope) + { + // Arrange + string capturedScope = null; + _msalMock + .Setup(m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((scopes, _, _) => capturedScope = scopes[0]) + .ReturnsAsync(CreateAuthenticationResult(TEST_ACCESS_TOKEN, DateTimeOffset.UtcNow.AddMinutes(30))); + + ClaimsPrincipal principal = CreatePrincipalWithOid(TEST_SUBJECT_OID, TEST_TENANT_ID); + + // Act + await _provider.GetAccessTokenOnBehalfOfAsync(principal, TEST_INCOMING_JWT, databaseAudience); + + // Assert + Assert.AreEqual(expectedScope, capturedScope); + } + + #endregion + + #region Error Handling Tests + + /// + /// Verifies that MSAL exceptions are wrapped in DataApiBuilderException with Unauthorized status. 
+ /// + [TestMethod] + public async Task GetAccessTokenOnBehalfOfAsync_MsalException_ThrowsUnauthorized() + { + // Arrange + _msalMock + .Setup(m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ThrowsAsync(new MsalServiceException("invalid_grant", "The user or admin has not consented.")); + + ClaimsPrincipal principal = CreatePrincipalWithOid(TEST_SUBJECT_OID, TEST_TENANT_ID); + + // Act & Assert + DataApiBuilderException ex = await Assert.ThrowsExceptionAsync( + async () => await _provider.GetAccessTokenOnBehalfOfAsync(principal, TEST_INCOMING_JWT, TEST_DATABASE_AUDIENCE)); + + Assert.AreEqual(HttpStatusCode.Unauthorized, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure, ex.SubStatusCode); + } + + #endregion + + #region Helper Methods + + /// + /// Creates an in-memory FusionCache instance for testing. + /// + private static IFusionCache CreateFusionCache() + { + return new FusionCache(new FusionCacheOptions + { + DefaultEntryOptions = new FusionCacheEntryOptions + { + Duration = TimeSpan.FromMinutes(30) + } + }); + } + + /// + /// Sets up the class-level MSAL mock to return a successful token acquisition result. + /// + private void SetupMsalSuccess(int tokenExpiryMinutes = 30) + { + _msalMock + .Setup(m => m.AcquireTokenOnBehalfOfAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(CreateAuthenticationResult(TEST_ACCESS_TOKEN, DateTimeOffset.UtcNow.AddMinutes(tokenExpiryMinutes))); + } + + /// + /// Creates a ClaimsPrincipal with oid and tid claims. + /// + private static ClaimsPrincipal CreatePrincipalWithOid(string oid, string tid) + { + ClaimsIdentity identity = new(); + identity.AddClaim(new Claim("oid", oid)); + identity.AddClaim(new Claim("tid", tid)); + return new ClaimsPrincipal(identity); + } + + /// + /// Creates a ClaimsPrincipal with oid, tid, and role claims. 
+ /// + private static ClaimsPrincipal CreatePrincipalWithRoles(string oid, string tid, params string[] roles) + { + ClaimsIdentity identity = new(); + identity.AddClaim(new Claim("oid", oid)); + identity.AddClaim(new Claim("tid", tid)); + foreach (string role in roles) + { + identity.AddClaim(new Claim(AuthenticationOptions.ROLE_CLAIM_TYPE, role)); + } + + return new ClaimsPrincipal(identity); + } + + /// + /// Creates a mock AuthenticationResult for testing. + /// + private static AuthenticationResult CreateAuthenticationResult(string accessToken, DateTimeOffset expiresOn) + { + return new AuthenticationResult( + accessToken: accessToken, + isExtendedLifeTimeToken: false, + uniqueId: Guid.NewGuid().ToString(), + expiresOn: expiresOn, + extendedExpiresOn: expiresOn, + tenantId: TEST_TENANT_ID, + account: null, + idToken: null, + scopes: new[] { $"{TEST_DATABASE_AUDIENCE}.default" }, + correlationId: Guid.NewGuid()); + } + + #endregion + } +} diff --git a/src/Service.Tests/Azure.DataApiBuilder.Service.Tests.csproj b/src/Service.Tests/Azure.DataApiBuilder.Service.Tests.csproj index d250822359..ae274a4dc2 100644 --- a/src/Service.Tests/Azure.DataApiBuilder.Service.Tests.csproj +++ b/src/Service.Tests/Azure.DataApiBuilder.Service.Tests.csproj @@ -97,6 +97,7 @@ + diff --git a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs index 1294c009da..d8cd279d95 100644 --- a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs +++ b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs @@ -416,4 +416,48 @@ private static string GetRawConfigJson(string globalCacheConfig, string entityCa return expectedRuntimeConfigJson.ToString(); } + + /// + /// Regression test: Validates that when global runtime cache is enabled but entity cache is disabled, + /// GetEntityCacheEntryTtl and GetEntityCacheEntryLevel do not throw and return sensible defaults. 
+ /// Previously, these methods threw a DataApiBuilderException (BadRequest/NotSupported) when the entity + /// had caching disabled, which caused 400 errors for valid requests when the global cache was enabled. + /// These methods are now pure accessors that always return a value regardless of cache enablement. + /// + /// Global cache configuration JSON fragment. + /// Entity cache configuration JSON fragment. + /// Expected TTL returned by GetEntityCacheEntryTtl. + /// Expected cache level returned by GetEntityCacheEntryLevel. + [DataRow(@",""cache"": { ""enabled"": true, ""ttl-seconds"": 10 }", @",""cache"": { ""enabled"": false }", 10, EntityCacheLevel.L1L2, DisplayName = "Global cache enabled with custom TTL, entity cache disabled: entity returns global TTL and default level.")] + [DataRow(@",""cache"": { ""enabled"": true }", @",""cache"": { ""enabled"": false }", 5, EntityCacheLevel.L1L2, DisplayName = "Global cache enabled with default TTL, entity cache disabled: entity returns default TTL and default level.")] + [DataRow(@",""cache"": { ""enabled"": true, ""ttl-seconds"": 10 }", @"", 10, EntityCacheLevel.L1L2, DisplayName = "Global cache enabled with custom TTL, entity cache omitted: entity returns global TTL and default level.")] + [DataTestMethod] + public void GetEntityCacheEntryTtlAndLevel_DoesNotThrow_WhenRuntimeCacheEnabledAndEntityCacheDisabled( + string globalCacheConfig, + string entityCacheConfig, + int expectedTtl, + EntityCacheLevel expectedLevel) + { + // Arrange + string fullConfig = GetRawConfigJson(globalCacheConfig: globalCacheConfig, entityCacheConfig: entityCacheConfig); + RuntimeConfigLoader.TryParseConfig( + json: fullConfig, + out RuntimeConfig? 
config, + replacementSettings: null); + + Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); + Assert.IsTrue(config.IsCachingEnabled, message: "Global caching should be enabled for this test scenario."); + + Entity entity = config.Entities.First().Value; + Assert.IsFalse(entity.IsCachingEnabled, message: "Entity caching should be disabled for this test scenario."); + + string entityName = config.Entities.First().Key; + + // Act & Assert - These calls must not throw. + int actualTtl = config.GetEntityCacheEntryTtl(entityName); + EntityCacheLevel actualLevel = config.GetEntityCacheEntryLevel(entityName); + + Assert.AreEqual(expected: expectedTtl, actual: actualTtl, message: "GetEntityCacheEntryTtl should return the global/default TTL when entity cache is disabled."); + Assert.AreEqual(expected: expectedLevel, actual: actualLevel, message: "GetEntityCacheEntryLevel should return the default level when entity cache is disabled."); + } } diff --git a/src/Service.Tests/Caching/DabCacheServiceIntegrationTests.cs b/src/Service.Tests/Caching/DabCacheServiceIntegrationTests.cs index 68c9225b96..91c6ef28bd 100644 --- a/src/Service.Tests/Caching/DabCacheServiceIntegrationTests.cs +++ b/src/Service.Tests/Caching/DabCacheServiceIntegrationTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Data.Common; +using System.IO.Abstractions; using System.Net; using System.Reflection; using System.Text.Json; @@ -699,10 +700,14 @@ private static Mock CreateMockSqlQueryStructure(string entity entityToDatabaseObject.Add(entityName, new DatabaseTable()); Mock mockRuntimeConfigProvider = CreateMockRuntimeConfigProvider(entityName); + IFileSystem fileSystem = new FileSystem(); + Mock> loggerValidator = new(); + RuntimeConfigValidator runtimeConfigValidator = new(mockRuntimeConfigProvider.Object, fileSystem, loggerValidator.Object); Mock mockQueryFactory = new(); Mock> mockLogger = new(); Mock 
mockSqlMetadataProvider = new( mockRuntimeConfigProvider.Object, + runtimeConfigValidator, mockQueryFactory.Object, mockLogger.Object, dataSourceName, diff --git a/src/Service.Tests/Configuration/CompressionIntegrationTests.cs b/src/Service.Tests/Configuration/CompressionIntegrationTests.cs new file mode 100644 index 0000000000..f578403569 --- /dev/null +++ b/src/Service.Tests/Configuration/CompressionIntegrationTests.cs @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.ResponseCompression; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using DabCompressionLevel = Azure.DataApiBuilder.Config.ObjectModel.CompressionLevel; +using SystemCompressionLevel = System.IO.Compression.CompressionLevel; + +namespace Azure.DataApiBuilder.Service.Tests.Configuration +{ + /// + /// Integration tests for HTTP response compression middleware. + /// Validates that compression reduces payload sizes and doesn't break existing functionality. + /// + [TestClass] + public class CompressionIntegrationTests + { + // Sample JSON payload for testing compression + private static readonly string _sampleJsonPayload = JsonSerializer.Serialize(new + { + data = Enumerable.Range(1, 100).Select(i => new + { + id = i, + title = $"Book Title {i}", + author = $"Author Name {i}", + description = $"This is a long description for book {i} to ensure we have enough data to compress effectively. " + + "Compression works best with repetitive text and structured data like JSON." 
+ }) + }); + + #region Positive Tests + + /// + /// Verify that responses are compressed when client sends Accept-Encoding header with gzip. + /// + [TestMethod("Responses are compressed with gzip when Accept-Encoding header is present.")] + public async Task TestResponseIsCompressedWithGzip() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Optimal); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "gzip"; + }); + + // Verify Content-Encoding header is present + Assert.IsTrue(returnContext.Response.Headers.ContentEncoding.Contains("gzip"), + "Response should have gzip Content-Encoding header"); + + // Verify response body exists by checking if we can read it + using (var reader = new StreamReader(returnContext.Response.Body)) + { + string content = await reader.ReadToEndAsync(); + Assert.IsTrue(content.Length > 0, "Response body should not be empty"); + } + } + + /// + /// Verify that responses are compressed with Brotli when client requests it. + /// + [TestMethod("Responses are compressed with Brotli when Accept-Encoding header specifies br.")] + public async Task TestResponseIsCompressedWithBrotli() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Optimal); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "br"; + }); + + Assert.IsTrue(returnContext.Response.Headers.ContentEncoding.Contains("br"), + "Response should have br Content-Encoding header"); + } + + /// + /// Verify that compression reduces payload size significantly. 
+ /// + [TestMethod("Compression reduces payload size for JSON responses.")] + public async Task TestCompressionReducesPayloadSize() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Optimal); + TestServer server = host.GetTestServer(); + + // Get uncompressed response + HttpContext uncompressedContext = await server.SendAsync(context => + { + // Don't set Accept-Encoding + }); + + using (var ms = new MemoryStream()) + { + await uncompressedContext.Response.Body.CopyToAsync(ms); + long uncompressedSize = ms.Length; + + // Get compressed response + HttpContext compressedContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "gzip"; + }); + + using (var cms = new MemoryStream()) + { + await compressedContext.Response.Body.CopyToAsync(cms); + long compressedSize = cms.Length; + + // Verify compressed size is smaller + Assert.IsTrue(compressedSize < uncompressedSize, + $"Compressed size ({compressedSize}) should be less than uncompressed size ({uncompressedSize})"); + + // Calculate compression ratio + double compressionRatio = (double)(uncompressedSize - compressedSize) / uncompressedSize * 100; + Console.WriteLine($"Compression achieved: {compressionRatio:F2}% reduction (from {uncompressedSize} to {compressedSize} bytes)"); + + // Verify at least some compression occurred (at least 10% for JSON) + Assert.IsTrue(compressionRatio > 10, $"Compression ratio should be at least 10%, got {compressionRatio:F2}%"); + } + } + } + + /// + /// Verify that compression is disabled when level is set to None. 
+ /// + [TestMethod("Responses are not compressed when compression level is None.")] + public async Task TestCompressionDisabledWhenLevelIsNone() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.None); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "gzip"; + }); + + Assert.IsFalse(returnContext.Response.Headers.ContentEncoding.Any(), + "Response should not have Content-Encoding header when compression is disabled"); + } + + /// + /// Verify that responses are not compressed when client doesn't send Accept-Encoding. + /// + [TestMethod("Responses are not compressed without Accept-Encoding header.")] + public async Task TestNoCompressionWithoutAcceptEncoding() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Optimal); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + // Don't set Accept-Encoding header + }); + + Assert.IsFalse(returnContext.Response.Headers.ContentEncoding.Any(), + "Response should not be compressed without Accept-Encoding header"); + } + + /// + /// Verify that fastest compression level works correctly. + /// + [TestMethod("Compression works with fastest level.")] + public async Task TestCompressionWithFastestLevel() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Fastest); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "gzip"; + }); + + Assert.IsTrue(returnContext.Response.Headers.ContentEncoding.Contains("gzip"), + "Response should be compressed with fastest level"); + } + + /// + /// Verify that compressed content can be decompressed correctly. 
+ /// + [TestMethod("Compressed content can be decompressed and is valid JSON.")] + public async Task TestCompressedContentCanBeDecompressed() + { + IHost host = await CreateCompressionConfiguredWebHost(DabCompressionLevel.Optimal); + TestServer server = host.GetTestServer(); + + HttpContext returnContext = await server.SendAsync(context => + { + context.Request.Headers.AcceptEncoding = "gzip"; + }); + + // Read compressed data + using (var ms = new MemoryStream()) + { + await returnContext.Response.Body.CopyToAsync(ms); + byte[] compressedData = ms.ToArray(); + + // Decompress + string decompressedContent = await DecompressGzipAsync(compressedData); + Assert.IsFalse(string.IsNullOrEmpty(decompressedContent), "Decompressed content should not be empty"); + + // Verify it's valid JSON matching our sample + JsonDocument doc = JsonDocument.Parse(decompressedContent); + Assert.IsTrue(doc.RootElement.TryGetProperty("data", out _), "Decompressed JSON should contain 'data' property"); + } + } + + #endregion + + #region Helper Methods + + /// + /// Creates a minimal compression-configured WebHost for testing. 
+ /// + private static async Task CreateCompressionConfiguredWebHost(DabCompressionLevel level) + { + return await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddHttpContextAccessor(); + + // Add response compression based on level + if (level != DabCompressionLevel.None) + { + SystemCompressionLevel systemLevel = level switch + { + DabCompressionLevel.Fastest => SystemCompressionLevel.Fastest, + DabCompressionLevel.Optimal => SystemCompressionLevel.Optimal, + _ => SystemCompressionLevel.Optimal + }; + + services.AddResponseCompression(options => + { + options.EnableForHttps = true; + options.Providers.Add(); + options.Providers.Add(); + }); + + services.Configure(options => + { + options.Level = systemLevel; + }); + + services.Configure(options => + { + options.Level = systemLevel; + }); + } + }) + .Configure(app => + { + // Add response compression middleware + if (level != DabCompressionLevel.None) + { + app.UseResponseCompression(); + } + + // Simple endpoint that returns JSON + app.Run(async context => + { + context.Response.ContentType = "application/json"; + await context.Response.WriteAsync(_sampleJsonPayload); + }); + }); + }) + .StartAsync(); + } + + /// + /// Decompresses gzip-compressed data. 
+ /// + private static async Task DecompressGzipAsync(byte[] data) + { + using (var compressedStream = new MemoryStream(data)) + using (var gzipStream = new GZipStream(compressedStream, CompressionMode.Decompress)) + using (var resultStream = new MemoryStream()) + { + await gzipStream.CopyToAsync(resultStream); + return Encoding.UTF8.GetString(resultStream.ToArray()); + } + } + + #endregion + } +} diff --git a/src/Service.Tests/Configuration/ConfigurationTests.cs b/src/Service.Tests/Configuration/ConfigurationTests.cs index 9df54be519..0ef9b67a4b 100644 --- a/src/Service.Tests/Configuration/ConfigurationTests.cs +++ b/src/Service.Tests/Configuration/ConfigurationTests.cs @@ -1966,6 +1966,63 @@ public void TestBasicConfigSchemaWithNoEntityFieldsIsInvalid() Assert.IsTrue(result.ErrorMessage.Contains("Total schema validation errors: 1\n> Required properties are missing from object: entities.")); } + /// + /// Validates that the JSON schema correctly validates entity cache configuration properties. + /// Tests both valid configurations (proper level values, ttl >= 1) and invalid configurations + /// (invalid level values, ttl = 0). 
+ /// + [DataTestMethod] + [DataRow("L1", 10, true, DisplayName = "Valid cache config with L1 and ttl=10")] + [DataRow("L1L2", 1, true, DisplayName = "Valid cache config with L1L2 and minimum ttl=1")] + [DataRow("L1L2", 3600, true, DisplayName = "Valid cache config with L1L2 and ttl=3600")] + [DataRow("L3", 10, false, DisplayName = "Invalid cache config with invalid level L3")] + [DataRow("L1", 0, false, DisplayName = "Invalid cache config with ttl=0 (below minimum)")] + [DataRow("L1L2", -1, false, DisplayName = "Invalid cache config with negative ttl")] + public void TestEntityCacheSchemaValidation(string level, int ttlSeconds, bool shouldBeValid) + { + string jsonData = $@"{{ + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": {{ + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }}, + ""entities"": {{ + ""Book"": {{ + ""source"": {{ + ""object"": ""books"", + ""type"": ""table"" + }}, + ""permissions"": [{{ + ""role"": ""anonymous"", + ""actions"": [""read""] + }}], + ""cache"": {{ + ""enabled"": true, + ""ttl-seconds"": {ttlSeconds}, + ""level"": ""{level}"" + }} + }} + }} + }}"; + + Mock> schemaValidatorLogger = new(); + string jsonSchema = File.ReadAllText("dab.draft.schema.json"); + JsonConfigSchemaValidator jsonSchemaValidator = new(schemaValidatorLogger.Object, new MockFileSystem()); + + JsonSchemaValidationResult result = jsonSchemaValidator.ValidateJsonConfigWithSchema(jsonSchema, jsonData); + + if (shouldBeValid) + { + Assert.IsTrue(result.IsValid, $"Expected valid config but got errors: {result.ErrorMessage}"); + Assert.IsTrue(EnumerableUtilities.IsNullOrEmpty(result.ValidationErrors)); + } + else + { + Assert.IsFalse(result.IsValid, "Expected validation to fail but it passed"); + Assert.IsFalse(EnumerableUtilities.IsNullOrEmpty(result.ValidationErrors)); + } + } + /// /// This test tries to validate a runtime config file that 
is not compliant with the runtime config JSON schema. /// It validates no additional properties are defined in the config file. @@ -2846,7 +2903,7 @@ public async Task ValidateErrorMessageForMutationWithoutReadPermission() }"; string queryName = "stock_by_pk"; - ValidateMutationSucceededAtDbLayer(server, client, graphQLQuery, queryName, authToken, AuthorizationResolver.ROLE_AUTHENTICATED); + await ValidateMutationSucceededAtDbLayer(server, client, graphQLQuery, queryName, authToken, AuthorizationResolver.ROLE_AUTHENTICATED); } finally { @@ -3168,7 +3225,7 @@ public async Task ValidateInheritanceOfReadPermissionFromAnonymous() /// GraphQL query/mutation text /// GraphQL query/mutation name /// Auth token for the graphQL request - private static async void ValidateMutationSucceededAtDbLayer(TestServer server, HttpClient client, string query, string queryName, string authToken, string clientRoleHeader) + private static async Task ValidateMutationSucceededAtDbLayer(TestServer server, HttpClient client, string query, string queryName, string authToken, string clientRoleHeader) { JsonElement queryResponse = await GraphQLRequestExecutor.PostGraphQLRequestAsync( client, @@ -3180,6 +3237,7 @@ private static async void ValidateMutationSucceededAtDbLayer(TestServer server, clientRoleHeader: clientRoleHeader); Assert.IsNotNull(queryResponse); + Assert.AreNotEqual(JsonValueKind.Null, queryResponse.ValueKind, "Expected a JSON object response but received null."); Assert.IsFalse(queryResponse.TryGetProperty("errors", out _)); } @@ -3624,8 +3682,11 @@ public void ValidateGraphQLSchemaForCircularReference(string schema) FileSystemRuntimeConfigLoader loader = new(fileSystem); RuntimeConfigProvider provider = new(loader); + Mock> loggerValidator = new(); + RuntimeConfigValidator validator = new(provider, fileSystem, loggerValidator.Object); + DataApiBuilderException exception = - Assert.ThrowsException(() => new CosmosSqlMetadataProvider(provider, fileSystem)); + 
Assert.ThrowsException(() => new CosmosSqlMetadataProvider(provider, validator, fileSystem)); Assert.AreEqual("Circular reference detected in the provided GraphQL schema for entity 'Character'.", exception.Message); Assert.AreEqual(HttpStatusCode.InternalServerError, exception.StatusCode); Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); @@ -3675,9 +3736,11 @@ type Planet @model(name:""PlanetAlias"") { }); FileSystemRuntimeConfigLoader loader = new(fileSystem); RuntimeConfigProvider provider = new(loader); + Mock> loggerValidator = new(); + RuntimeConfigValidator validator = new(provider, fileSystem, loggerValidator.Object); DataApiBuilderException exception = - Assert.ThrowsException(() => new CosmosSqlMetadataProvider(provider, fileSystem)); + Assert.ThrowsException(() => new CosmosSqlMetadataProvider(provider, validator, fileSystem)); Assert.AreEqual("The entity 'Character' was not found in the runtime config.", exception.Message); Assert.AreEqual(HttpStatusCode.ServiceUnavailable, exception.StatusCode); Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, exception.SubStatusCode); @@ -5231,6 +5294,346 @@ public async Task TestGraphQLIntrospectionQueriesAreNotImpactedByDepthLimit() } } + /// + /// + /// + /// + /// + /// + [TestCategory(TestCategory.MSSQL)] + [DataTestMethod] + [DataRow(true, 4, DisplayName = "Test Autoentities with additional entities")] + [DataRow(false, 2, DisplayName = "Test Autoentities without additional entities")] + public async Task TestAutoentitiesAreGeneratedIntoEntities(bool useEntities, int expectedEntityCount) + { + // Arrange + EntityRelationship bookRelationship = new(Cardinality: Cardinality.One, + TargetEntity: "BookPublisher", + SourceFields: new string[] { }, + TargetFields: new string[] { }, + LinkingObject: null, + LinkingSourceFields: null, + LinkingTargetFields: null); + + Entity bookEntity = new(Source: new("books", EntitySourceType.Table, 
null, null), + Fields: null, + Rest: null, + GraphQL: new(Singular: "book", Plural: "books"), + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) }, + Relationships: new Dictionary() { { "publishers", bookRelationship } }, + Mappings: null); + + EntityRelationship publisherRelationship = new(Cardinality: Cardinality.Many, + TargetEntity: "Book", + SourceFields: new string[] { }, + TargetFields: new string[] { }, + LinkingObject: null, + LinkingSourceFields: null, + LinkingTargetFields: null); + + Entity publisherEntity = new( + Source: new("publishers", EntitySourceType.Table, null, null), + Fields: null, + Rest: null, + GraphQL: new(Singular: "bookpublisher", Plural: "bookpublishers"), + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) }, + Relationships: new Dictionary() { { "books", publisherRelationship } }, + Mappings: null); + + Dictionary entityMap = new() + { + { "Book", bookEntity }, + { "BookPublisher", publisherEntity } + }; + + Dictionary autoentityMap = new() + { + { + "PublisherAutoEntity", new Autoentity( + Patterns: new AutoentityPatterns( + Include: new[] { "%publishers%" }, + Exclude: null, + Name: null + ), + Template: new AutoentityTemplate( + Rest: new EntityRestOptions(Enabled: true), + GraphQL: new EntityGraphQLOptions( + Singular: string.Empty, + Plural: string.Empty, + Enabled: true + ), + Health: null, + Cache: null + ), + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) } + ) + } + }; + + // Create DataSource for MSSQL connection + DataSource dataSource = new(DatabaseType.MSSQL, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL), Options: null); + + // Build complete runtime configuration with autoentities + RuntimeConfig configuration = new( + Schema: "TestAutoentitiesSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(Enabled: true), + GraphQL: new(Enabled: true), + Mcp: new(Enabled: false), + 
Host: new( + Cors: null, + Authentication: new Config.ObjectModel.AuthenticationOptions( + Provider: nameof(EasyAuthType.StaticWebApps), + Jwt: null + ) + ) + ), + Entities: new(useEntities ? entityMap : new Dictionary()), + Autoentities: new RuntimeAutoentities(autoentityMap) + ); + + File.WriteAllText(CUSTOM_CONFIG_FILENAME, configuration.ToJson()); + + string[] args = new[] { $"--ConfigFileName={CUSTOM_CONFIG_FILENAME}" }; + + using (TestServer server = new(Program.CreateWebHostBuilder(args))) + using (HttpClient client = server.CreateClient()) + { + // Act + RuntimeConfigProvider configProvider = server.Services.GetService(); + using HttpRequestMessage restRequest = new(HttpMethod.Get, "/api/publishers"); + using HttpResponseMessage restResponse = await client.SendAsync(restRequest); + + string graphqlQuery = @" + { + publishers { + items { + id + name + } + } + }"; + + object graphqlPayload = new { query = graphqlQuery }; + HttpRequestMessage graphqlRequest = new(HttpMethod.Post, "/graphql") + { + Content = JsonContent.Create(graphqlPayload) + }; + HttpResponseMessage graphqlResponse = await client.SendAsync(graphqlRequest); + + // Assert + string expectedResponseFragment = @"{""id"":1156,""name"":""The First Publisher""}"; + + // Verify number of entities + Assert.AreEqual(expectedEntityCount, configProvider.GetConfig().Entities.Entities.Count, "Number of generated entities is not what is expected"); + + // Verify REST response + Assert.AreEqual(HttpStatusCode.OK, restResponse.StatusCode, "REST request to auto-generated entity should succeed"); + + string restResponseBody = await restResponse.Content.ReadAsStringAsync(); + Assert.IsTrue(!string.IsNullOrEmpty(restResponseBody), "REST response should contain data"); + Assert.IsTrue(restResponseBody.Contains(expectedResponseFragment)); + + // Verify GraphQL response + Assert.AreEqual(HttpStatusCode.OK, graphqlResponse.StatusCode, "GraphQL request to auto-generated entity should succeed"); + + string 
graphqlResponseBody = await graphqlResponse.Content.ReadAsStringAsync(); + Assert.IsTrue(!string.IsNullOrEmpty(graphqlResponseBody), "GraphQL response should contain data"); + Assert.IsFalse(graphqlResponseBody.Contains("errors"), "GraphQL response should not contain errors"); + Assert.IsTrue(graphqlResponseBody.Contains(expectedResponseFragment)); + } + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + [TestCategory(TestCategory.MSSQL)] + [DataTestMethod] + [DataRow("publishers", "uniqueSingularPublisher", "uniquePluralPublishers", "/unique/publisher", "Entity with name 'publishers' already exists. Cannot create new entity from autoentity pattern with definition-name 'PublisherAutoEntity'.", DisplayName = "Autoentities fail due to entity name")] + [DataRow("UniquePublisher", "publishers", "uniquePluralPublishers", "/unique/publisher", "Entity publishers generates queries/mutation that already exist", DisplayName = "Autoentities fail due to graphql singular type")] + [DataRow("UniquePublisher", "uniqueSingularPublisher", "publishers", "/unique/publisher", "Entity publishers generates queries/mutation that already exist", DisplayName = "Autoentities fail due to graphql plural type")] + [DataRow("UniquePublisher", "uniqueSingularPublisher", "uniquePluralPublishers", "/publishers", "The rest path: publishers specified for entity: publishers is already used by another entity.", DisplayName = "Autoentities fail due to rest path")] + public async Task ValidateAutoentityGenerationConflicts(string entityName, string singular, string plural, string path, string exceptionMessage) + { + // Arrange + Entity publisherEntity = new( + Source: new("publishers", EntitySourceType.Table, null, null), + Fields: null, + Rest: new(Path: path), + GraphQL: new(Singular: singular, Plural: plural), + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) }, + Relationships: null, + Mappings: null); + + Dictionary entityMap = new() + { + { entityName, 
publisherEntity } + }; + + Dictionary autoentityMap = new() + { + { + "PublisherAutoEntity", new Autoentity( + Patterns: new AutoentityPatterns( + Include: new[] { "%publishers%" }, + Exclude: null, + Name: null + ), + Template: new AutoentityTemplate( + Rest: new EntityRestOptions( + Enabled: true), + GraphQL: new EntityGraphQLOptions( + Singular: string.Empty, + Plural: string.Empty, + Enabled: true + ), + Health: null, + Cache: null + ), + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) } + ) + } + }; + + // Create DataSource for MSSQL connection + DataSource dataSource = new(DatabaseType.MSSQL, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL), Options: null); + + // Build complete runtime configuration with autoentities + RuntimeConfig configuration = new( + Schema: "TestAutoentitiesSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(Enabled: true), + GraphQL: new(Enabled: true), + Mcp: new(Enabled: false), + Host: new( + Mode: HostMode.Development, + Cors: null, + Authentication: new Config.ObjectModel.AuthenticationOptions( + Provider: nameof(EasyAuthType.StaticWebApps), + Jwt: null + ) + ) + ), + Entities: new(entityMap), + Autoentities: new RuntimeAutoentities(autoentityMap) + ); + + File.WriteAllText(CUSTOM_CONFIG_FILENAME, configuration.ToJson()); + + ILoggerFactory loggerFactory = new LoggerFactory(); + IFileSystem fileSystem = new FileSystem(); + + FileSystemRuntimeConfigLoader configLoader = new(fileSystem) + { + RuntimeConfig = configuration + }; + + RuntimeConfigProvider configProvider = new(configLoader); + + RuntimeConfigValidator configValidator = new(configProvider, fileSystem, loggerFactory.CreateLogger()); + + QueryManagerFactory queryManagerFactory = new( + runtimeConfigProvider: configProvider, + logger: loggerFactory.CreateLogger(), + contextAccessor: null!, + handler: null); + + MsSqlMetadataProvider provider = new( + configProvider, + configValidator, + 
queryManagerFactory, + loggerFactory.CreateLogger(), + configLoader.RuntimeConfig.DefaultDataSourceName, + false); + + try + { + await provider.InitializeAsync(); + Assert.Fail("It is expected for DAB to fail due to entities not containing unique parameters."); + } + catch (DataApiBuilderException ex) + { + Assert.AreEqual(exceptionMessage, ex.Message); + } + } + + /// + /// Validates the autoentity configuration inside the configuration file and also + /// validates that entities created from the autoentity configuration do not generate + /// duplicate entities and paths for REST and GraphQL. + /// + /// + [TestCategory(TestCategory.MSSQL)] + [TestMethod] + public async Task ValidateAutoentitiesConfiguration() + { + EntityAction entityAction = new(EntityActionOperation.Read, null, null); + + Dictionary autoentityMap = new(); + string autoentityName = "AutoentityA"; + + Autoentity autoentity = new( + Patterns: new AutoentityPatterns( + Include: new[] { "%patterns%" }, + Exclude: new[] { "%books%" }, + Name: "{object}"), + Template: new AutoentityTemplate( + Rest: new(Enabled: false), + GraphQL: new(Enabled: true, Singular: string.Empty, Plural: string.Empty), + Health: new(enabled: true), + Cache: new(Enabled: true, TtlSeconds: 50)), + Permissions: new EntityPermission[] { new("anonymous", new EntityAction[] { entityAction }) }); + + autoentityMap.Add(autoentityName, autoentity); + + DataSource dataSource = new(DatabaseType.MSSQL, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL), Options: null); + + RuntimeConfig runtimeConfig = new( + Schema: "TestAutoentitiesSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null, HostMode.Development)), + Entities: new(new Dictionary()), + Autoentities: new(autoentityMap)); + + const string CUSTOM_CONFIG = "autoentities-validation-config.json"; + + File.WriteAllText(CUSTOM_CONFIG, runtimeConfig.ToJson()); + IFileSystem fileSystem = new 
FileSystem(); + + FileSystemRuntimeConfigLoader loader = new(fileSystem) + { + RuntimeConfig = runtimeConfig + }; + + RuntimeConfigProvider provider = new(loader); + Mock> loggerMock = new(); + RuntimeConfigValidator configValidator = new(provider, fileSystem, loggerMock.Object); + + try + { + await configValidator.TryValidateConfig(CUSTOM_CONFIG, TestHelper.ProvisionLoggerFactory()); + } + catch (Exception ex) + { + Assert.Fail(ex.Message); + } + } + /// /// Tests the behavior of GraphQL queries in non-hosted mode when the depth limit is explicitly set to -1 or null. /// Setting the depth limit to -1 is intended to disable the depth limit check, allowing queries of any depth. @@ -5690,18 +6093,26 @@ public static async Task GetMcpResponse(HttpClient httpClient, M HttpStatusCode responseCode = HttpStatusCode.ServiceUnavailable; while (retryCount < RETRY_COUNT) { - // Minimal MCP request (list tools) – valid JSON-RPC request + // Minimal MCP request (initialize) - valid JSON-RPC request. + // Using 'initialize' because 'tools/list' requires an active session + // in the MCP Streamable HTTP transport (ModelContextProtocol 1.0.0). 
object payload = new { jsonrpc = "2.0", id = 1, - method = "tools/list" + method = "initialize", + @params = new + { + protocolVersion = "2025-03-26", + capabilities = new { }, + clientInfo = new { name = "dab-test", version = "1.0.0" } + } }; HttpRequestMessage mcpRequest = new(HttpMethod.Post, mcp.Path) { Content = JsonContent.Create(payload) }; - mcpRequest.Headers.Add("Accept", "*/*"); + mcpRequest.Headers.Add("Accept", "application/json, text/event-stream"); HttpResponseMessage mcpResponse = await httpClient.SendAsync(mcpRequest); responseCode = mcpResponse.StatusCode; diff --git a/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs b/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs index 7dcf837d08..d6f19ec65f 100644 --- a/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs +++ b/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs @@ -101,4 +101,34 @@ public async Task FailLoadMultiDataSourceConfigDuplicateEntities(string configPa Assert.IsTrue(error.StartsWith("Deserialization of the configuration file failed during a post-processing step.")); Assert.IsTrue(error.Contains("An item with the same key has already been added.")); } + + /// + /// Test validates that when child files are present all autoentities are loaded correctly. 
+ /// + [DataTestMethod] + [DataRow("Multidab-config.CosmosDb_NoSql.json", new string[] { "Multidab-config.MsSql.json", "Multidab-config.MySql.json", "Multidab-config.PostgreSql.json" }, 36)] + public async Task CanLoadValidMultiSourceConfigWithAutoentities(string configPath, IEnumerable dataSourceFiles, int expectedEntities) + { + string fileContents = await File.ReadAllTextAsync(configPath); + + // Parse the base JSON string + JObject baseJsonObject = JObject.Parse(fileContents); + + // Create a new JArray to hold the values to be appended + JArray valuesToAppend = new(dataSourceFiles); + + // Add or append the values to the base JSON + baseJsonObject.Add("data-source-files", valuesToAppend); + + // Convert the modified JSON object back to a JSON string + string resultJson = baseJsonObject.ToString(); + + IFileSystem fs = new MockFileSystem(new Dictionary() { { "dab-config.json", new MockFileData(resultJson) } }); + + FileSystemRuntimeConfigLoader loader = new(fs); + + Assert.IsTrue(loader.TryLoadConfig("dab-config.json", out RuntimeConfig runtimeConfig), "Should successfully load config"); + Assert.IsTrue(runtimeConfig.SqlDataSourceUsed, "Should have Sql data source"); + Assert.AreEqual(expectedEntities, runtimeConfig.Entities.Entities.Count, "Number of entities is not what is expected."); + } } diff --git a/src/Service.Tests/Configuration/Telemetry/AzureLogAnalyticsTests.cs b/src/Service.Tests/Configuration/Telemetry/AzureLogAnalyticsTests.cs index db6b58681b..dc35b1c4dd 100644 --- a/src/Service.Tests/Configuration/Telemetry/AzureLogAnalyticsTests.cs +++ b/src/Service.Tests/Configuration/Telemetry/AzureLogAnalyticsTests.cs @@ -120,9 +120,18 @@ public async Task TestAzureLogAnalyticsFlushServiceSucceed(string message, LogLe _ = Task.Run(() => flusherService.StartAsync(tokenSource.Token)); - await Task.Delay(2000); + // Poll until the log appears (the flusher service needs time to dequeue and upload) + int maxWaitMs = 10000; + int pollIntervalMs = 100; + int 
elapsed = 0; + while (customClient.LogAnalyticsLogs.Count == 0 && elapsed < maxWaitMs) + { + await Task.Delay(pollIntervalMs); + elapsed += pollIntervalMs; + } // Assert + Assert.IsTrue(customClient.LogAnalyticsLogs.Count > 0, $"Expected at least one log entry after waiting {elapsed}ms, but found none."); AzureLogAnalyticsLogs actualLog = customClient.LogAnalyticsLogs[0]; Assert.AreEqual(logLevel.ToString(), actualLog.LogLevel); Assert.AreEqual(message, actualLog.Message); diff --git a/src/Service.Tests/CosmosTests/QueryFilterTests.cs b/src/Service.Tests/CosmosTests/QueryFilterTests.cs index 187b447973..636988331c 100644 --- a/src/Service.Tests/CosmosTests/QueryFilterTests.cs +++ b/src/Service.Tests/CosmosTests/QueryFilterTests.cs @@ -899,6 +899,31 @@ public async Task TestFilterWithEntityNameAlias() await ExecuteAndValidateResult(_graphQLQueryName, gqlQuery, dbQuery); } + /// + /// Test filters on two different nested objects simultaneously + /// + [TestMethod] + public async Task TestFilterOnTwoDifferentNestedObjects() + { + string gqlQuery = @"{ + planets(first: 10, " + QueryBuilder.FILTER_FIELD_NAME + @" : { + character: { name: { eq: ""planet character"" } }, + earth: { type: { eq: ""earth4"" } } + }) + { + items { + id + name + } + } + }"; + + string dbQuery = "SELECT c.id, c.name FROM c " + + "WHERE c.character.name = \"planet character\" AND c.earth.type = \"earth4\""; + + await ExecuteAndValidateResult(_graphQLQueryName, gqlQuery, dbQuery); + } + /// /// For "item-level-permission-role" role, DB policies are defined. This test confirms that all the DB policies are considered. /// For the reference, Below conditions are applied for an Entity in Db Config file. 
diff --git a/src/Service.Tests/CosmosTests/TestBase.cs b/src/Service.Tests/CosmosTests/TestBase.cs index 8617534776..dfac601023 100644 --- a/src/Service.Tests/CosmosTests/TestBase.cs +++ b/src/Service.Tests/CosmosTests/TestBase.cs @@ -21,6 +21,7 @@ using Microsoft.AspNetCore.TestHost; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Newtonsoft.Json.Linq; @@ -158,7 +159,10 @@ protected WebApplicationFactory SetupTestApplicationFactory() FileSystemRuntimeConfigLoader loader = new(fileSystem); RuntimeConfigProvider provider = new(loader); - ISqlMetadataProvider cosmosSqlMetadataProvider = new CosmosSqlMetadataProvider(provider, fileSystem); + Mock> loggerValidator = new(); + RuntimeConfigValidator validator = new(provider, fileSystem, loggerValidator.Object); + + ISqlMetadataProvider cosmosSqlMetadataProvider = new CosmosSqlMetadataProvider(provider, validator, fileSystem); Mock metadataProviderFactory = new(); metadataProviderFactory.Setup(x => x.GetMetadataProvider(It.IsAny())).Returns(cosmosSqlMetadataProvider); diff --git a/src/Service.Tests/GraphQLBuilder/MultiSourceBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/MultiSourceBuilderTests.cs index 7c40dafcd6..67bfa8d2fe 100644 --- a/src/Service.Tests/GraphQLBuilder/MultiSourceBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/MultiSourceBuilderTests.cs @@ -46,11 +46,14 @@ public async Task CosmosSchemaBuilderTestAsync() RuntimeConfigProvider provider = new(loader); + Mock> loggerValidator = new(); + RuntimeConfigValidator validator = new(provider, fs, loggerValidator.Object); + Mock queryManagerfactory = new(); Mock queryEngineFactory = new(); Mock mutationEngineFactory = new(); Mock> logger = new(); - IMetadataProviderFactory metadataProviderFactory = new MetadataProviderFactory(provider, queryManagerfactory.Object, logger.Object, fs, handler: null); + 
IMetadataProviderFactory metadataProviderFactory = new MetadataProviderFactory(provider, validator, queryManagerfactory.Object, logger.Object, fs, handler: null); Mock authResolver = new(); GraphQLSchemaCreator creator = new(provider, queryEngineFactory.Object, mutationEngineFactory.Object, metadataProviderFactory, authResolver.Object); diff --git a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs index 94665d7c18..ffe636f6db 100644 --- a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Collections.Generic; +using System.IO.Abstractions; using System.Linq; using System.Threading.Tasks; using Azure.DataApiBuilder.Auth; @@ -392,6 +393,16 @@ private static async Task GetGQLSchemaCreator(RuntimeConfi Mock cache = new(); DabCacheService cacheService = new(cache: cache.Object, logger: null, httpContextAccessor: httpContextAccessor.Object); + // Setup runtime config validator + IFileSystem fileSystem = new FileSystem(); + FileSystemRuntimeConfigLoader configLoader = new(fileSystem) + { + RuntimeConfig = _runtimeConfig + }; + RuntimeConfigProvider configProvider = new(configLoader); + Mock> loggerValidator = new(); + RuntimeConfigValidator configValidator = new(configProvider, fileSystem, loggerValidator.Object); + // Setup query manager factory. IAbstractQueryManagerFactory queryManagerfactory = new QueryManagerFactory( runtimeConfigProvider: runtimeConfigProvider, @@ -402,6 +413,7 @@ private static async Task GetGQLSchemaCreator(RuntimeConfi // Setup metadata provider factory. 
IMetadataProviderFactory metadataProviderFactory = new MetadataProviderFactory( runtimeConfigProvider: runtimeConfigProvider, + runtimeConfigValidator: configValidator, queryManagerFactory: queryManagerfactory, logger: metadatProviderLogger.Object, fileSystem: null, diff --git a/src/Service.Tests/Mcp/DescribeEntitiesFilteringTests.cs b/src/Service.Tests/Mcp/DescribeEntitiesFilteringTests.cs new file mode 100644 index 0000000000..3defc34bba --- /dev/null +++ b/src/Service.Tests/Mcp/DescribeEntitiesFilteringTests.cs @@ -0,0 +1,504 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for DescribeEntitiesTool filtering logic (GitHub issue #3043). + /// Validates that entities with dml-tools: false are filtered from describe_entities, + /// regardless of entity type (tables, views, stored procedures). + /// When dml-tools is disabled, entities are not exposed via DML tools and should not appear in describe_entities. + /// + [TestClass] + public class DescribeEntitiesFilteringTests + { + /// + /// Verifies that when ALL entities have dml-tools: false, + /// describe_entities returns an AllEntitiesFilteredDmlDisabled error with guidance. + /// This ensures users understand why describe_entities is empty. 
+ /// + [TestMethod] + public async Task DescribeEntities_AllEntitiesFilteredWhenDmlToolsDisabled() + { + // Arrange + RuntimeConfig config = CreateConfigWithCustomToolSP(); + IServiceProvider serviceProvider = CreateServiceProvider(config); + DescribeEntitiesTool tool = new(); + + // Act + CallToolResult result = await tool.ExecuteAsync(null, serviceProvider, CancellationToken.None); + + // Assert + AssertErrorResult(result, "AllEntitiesFilteredDmlDisabled"); + + // Verify the error message is helpful + JsonElement content = GetContentFromResult(result); + content.TryGetProperty("error", out JsonElement error); + Assert.IsTrue(error.TryGetProperty("message", out JsonElement errorMessage)); + string message = errorMessage.GetString() ?? string.Empty; + Assert.IsTrue(message.Contains("DML tools disabled") || message.Contains("dml-tools")); + Assert.IsTrue(message.Contains("tools/list") || message.Contains("custom-tool")); + } + + /// + /// Verifies that stored procedures with dml-tools enabled (or default) appear in describe_entities, + /// while stored procedures with dml-tools: false are filtered out. + /// This ensures filtering is based on dml-tools configuration. + /// + [TestMethod] + public async Task DescribeEntities_IncludesRegularStoredProcedures() + { + // Arrange + RuntimeConfig config = CreateConfigWithMixedStoredProcedures(); + + // Act & Assert + CallToolResult result = await ExecuteToolAsync(config); + AssertSuccessResultWithEntityNames(result, new[] { "CountBooks" }, new[] { "GetBook" }); + } + + /// + /// Verifies that tables and views with default/enabled dml-tools appear in describe_entities, + /// while stored procedures with dml-tools: false are filtered out. + /// This ensures filtering applies based on the dml-tools setting, not entity type. 
+ /// + [TestMethod] + public async Task DescribeEntities_IncludesTablesAndViewsWithDmlToolsEnabled() + { + // Arrange & Act & Assert + RuntimeConfig config = CreateConfigWithMixedEntityTypes(); + CallToolResult result = await ExecuteToolAsync(config); + AssertSuccessResultWithEntityNames(result, new[] { "Book", "BookView" }, new[] { "GetBook" }); + } + + /// + /// Verifies that the 'count' field in describe_entities response accurately reflects + /// the number of entities AFTER filtering (excludes entities with dml-tools: false). + /// This ensures count matches the actual entities array length. + /// + [TestMethod] + public async Task DescribeEntities_CountReflectsFilteredList() + { + // Arrange + RuntimeConfig config = CreateConfigWithMixedEntityTypes(); + + // Act + CallToolResult result = await ExecuteToolAsync(config); + + // Assert + Assert.IsTrue(result.IsError == false || result.IsError == null); + JsonElement content = GetContentFromResult(result); + Assert.IsTrue(content.TryGetProperty("entities", out JsonElement entities)); + Assert.IsTrue(content.TryGetProperty("count", out JsonElement countElement)); + + int entityCount = entities.GetArrayLength(); + Assert.AreEqual(2, entityCount, "Config has 3 entities but only 2 should be returned (entity with dml-tools:false excluded)"); + Assert.AreEqual(entityCount, countElement.GetInt32(), "Count field should match filtered entity array length"); + } + + /// + /// Verifies that dml-tools filtering is applied consistently regardless of the nameOnly parameter. + /// When nameOnly=true (lightweight response), entities with dml-tools: false are still filtered out. + /// This ensures filtering behavior is consistent across both response modes. 
+ /// + [TestMethod] + public async Task DescribeEntities_NameOnlyWorksWithFiltering() + { + // Arrange + RuntimeConfig config = CreateConfigWithMixedEntityTypes(); + IServiceProvider serviceProvider = CreateServiceProvider(config); + DescribeEntitiesTool tool = new(); + JsonDocument arguments = JsonDocument.Parse("{\"nameOnly\": true}"); + + // Act + CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, CancellationToken.None); + + // Assert + AssertSuccessResultWithEntityNames(result, new[] { "Book", "BookView" }, new[] { "GetBook" }); + } + + /// + /// Test that NoEntitiesConfigured error is returned when runtime config truly has no entities. + /// This is different from AllEntitiesFilteredDmlDisabled where entities exist but are filtered. + /// + [TestMethod] + public async Task DescribeEntities_ReturnsNoEntitiesConfigured_WhenConfigHasNoEntities() + { + // Arrange & Act + RuntimeConfig config = CreateConfigWithNoEntities(); + CallToolResult result = await ExecuteToolAsync(config); + + // Assert + AssertErrorResult(result, "NoEntitiesConfigured"); + + // Verify the error message indicates no entities configured + JsonElement content = GetContentFromResult(result); + content.TryGetProperty("error", out JsonElement error); + Assert.IsTrue(error.TryGetProperty("message", out JsonElement errorMessage)); + string message = errorMessage.GetString() ?? string.Empty; + Assert.IsTrue(message.Contains("No entities are configured")); + } + + /// + /// CRITICAL TEST: Verifies that stored procedures with BOTH custom-tool AND dml-tools enabled + /// appear in describe_entities. This validates the truth table scenario: + /// custom-tool: true, dml-tools: true → ✔ describe_entities + ✔ tools/list + /// + /// This test ensures the filtering logic only filters when dml-tools is FALSE, + /// not just when custom-tool is TRUE. 
+ /// + [TestMethod] + public async Task DescribeEntities_IncludesCustomToolWithDmlEnabled() + { + // Arrange & Act + RuntimeConfig config = CreateConfigWithCustomToolAndDmlEnabled(); + CallToolResult result = await ExecuteToolAsync(config); + + // Assert + AssertSuccessResultWithEntityNames(result, new[] { "GetBook" }, Array.Empty()); + } + + /// + /// Verifies that when some (but not all) entities have dml-tools: false, + /// only non-filtered entities appear in the response. + /// This validates partial filtering works correctly with accurate count. + /// + [TestMethod] + public async Task DescribeEntities_ReturnsOnlyNonFilteredEntities_WhenPartiallyFiltered() + { + // Arrange & Act + RuntimeConfig config = CreateConfigWithMixedEntityTypes(); + CallToolResult result = await ExecuteToolAsync(config); + + // Assert + AssertSuccessResultWithEntityNames(result, new[] { "Book", "BookView" }, new[] { "GetBook" }); + + // Verify count matches + JsonElement content = GetContentFromResult(result); + Assert.IsTrue(content.TryGetProperty("count", out JsonElement countElement)); + Assert.AreEqual(2, countElement.GetInt32()); + } + + /// + /// Verifies that entities with DML tools disabled (dml-tools: false) are filtered from describe_entities. + /// This ensures the filtering applies to all entity types, not just stored procedures. 
+ /// + [DataTestMethod] + [DataRow(EntitySourceType.Table, "Publisher", "Book", DisplayName = "Filters Table with DML disabled")] + [DataRow(EntitySourceType.View, "Book", "BookView", DisplayName = "Filters View with DML disabled")] + public async Task DescribeEntities_FiltersEntityWithDmlToolsDisabled(EntitySourceType filteredEntityType, string includedEntityName, string filteredEntityName) + { + // Arrange + RuntimeConfig config = CreateConfigWithEntityDmlDisabled(filteredEntityType, includedEntityName, filteredEntityName); + IServiceProvider serviceProvider = CreateServiceProvider(config); + DescribeEntitiesTool tool = new(); + + // Act + CallToolResult result = await tool.ExecuteAsync(null, serviceProvider, CancellationToken.None); + + // Assert + AssertSuccessResultWithEntityNames(result, new[] { includedEntityName }, new[] { filteredEntityName }); + } + + /// + /// Verifies that when ALL entities have dml-tools disabled, the appropriate error is returned. + /// This tests the error scenario applies to all entity types, not just stored procedures. + /// + [TestMethod] + public async Task DescribeEntities_ReturnsAllEntitiesFilteredDmlDisabled_WhenAllEntitiesHaveDmlDisabled() + { + // Arrange & Act + RuntimeConfig config = CreateConfigWithAllEntitiesDmlDisabled(); + CallToolResult result = await ExecuteToolAsync(config); + + // Assert + AssertErrorResult(result, "AllEntitiesFilteredDmlDisabled"); + + // Verify the error message is helpful + JsonElement content = GetContentFromResult(result); + content.TryGetProperty("error", out JsonElement error); + Assert.IsTrue(error.TryGetProperty("message", out JsonElement errorMessage)); + string message = errorMessage.GetString() ?? 
string.Empty; + Assert.IsTrue(message.Contains("DML tools disabled"), "Error message should mention DML tools disabled"); + Assert.IsTrue(message.Contains("dml-tools: false"), "Error message should mention the config syntax"); + } + + #region Helper Methods + + /// + /// Executes the DescribeEntitiesTool with the given config. + /// + private static async Task ExecuteToolAsync(RuntimeConfig config, JsonDocument arguments = null) + { + IServiceProvider serviceProvider = CreateServiceProvider(config); + DescribeEntitiesTool tool = new(); + return await tool.ExecuteAsync(arguments, serviceProvider, CancellationToken.None); + } + + /// + /// Runs the DescribeEntitiesTool and asserts successful execution with expected entity names. + /// + private static void AssertSuccessResultWithEntityNames(CallToolResult result, string[] includedEntities, string[] excludedEntities) + { + Assert.IsTrue(result.IsError == false || result.IsError == null); + JsonElement content = GetContentFromResult(result); + Assert.IsTrue(content.TryGetProperty("entities", out JsonElement entities)); + + List entityNames = entities.EnumerateArray() + .Select(e => e.GetProperty("name").GetString()!) + .ToList(); + + foreach (string includedEntity in includedEntities) + { + Assert.IsTrue(entityNames.Contains(includedEntity), $"{includedEntity} should be included"); + } + + foreach (string excludedEntity in excludedEntities) + { + Assert.IsFalse(entityNames.Contains(excludedEntity), $"{excludedEntity} should be excluded"); + } + + Assert.AreEqual(includedEntities.Length, entities.GetArrayLength()); + } + + /// + /// Asserts that the result contains an error with the specified type. 
+ /// + private static void AssertErrorResult(CallToolResult result, string expectedErrorType) + { + Assert.IsTrue(result.IsError == true); + JsonElement content = GetContentFromResult(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.IsTrue(error.TryGetProperty("type", out JsonElement errorType)); + Assert.AreEqual(expectedErrorType, errorType.GetString()); + } + + /// + /// Creates a basic entity with standard permissions. + /// + private static Entity CreateEntity(string sourceName, EntitySourceType sourceType, string singularName, string pluralName, EntityMcpOptions mcpOptions = null) + { + EntityActionOperation action = sourceType == EntitySourceType.StoredProcedure + ? EntityActionOperation.Execute + : EntityActionOperation.Read; + + return new Entity( + Source: new(sourceName, sourceType, null, null), + GraphQL: new(singularName, pluralName), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { new EntityAction(Action: action, Fields: null, Policy: null) }) }, + Mappings: null, + Relationships: null, + Mcp: mcpOptions + ); + } + + /// + /// Creates a runtime config with the specified entities. + /// + private static RuntimeConfig CreateRuntimeConfig(Dictionary entities) + { + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(Enabled: true, Path: "/mcp", DmlTools: null), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + /// + /// Creates a runtime config with a stored procedure that has dml-tools: false. + /// Used to test the AllEntitiesFilteredDmlDisabled error scenario. 
+ /// + private static RuntimeConfig CreateConfigWithCustomToolSP() + { + Dictionary entities = new() + { + ["GetBook"] = CreateEntity("get_book", EntitySourceType.StoredProcedure, "GetBook", "GetBook", + new EntityMcpOptions(customToolEnabled: true, dmlToolsEnabled: false)) + }; + + return CreateRuntimeConfig(entities); + } + + /// + /// Creates a runtime config with mixed stored procedures: + /// one SP with dml-tools enabled/default (CountBooks) and one with dml-tools: false (GetBook). + /// Used to test that filtering is based on dml-tools configuration. + /// + private static RuntimeConfig CreateConfigWithMixedStoredProcedures() + { + Dictionary entities = new() + { + ["CountBooks"] = CreateEntity("count_books", EntitySourceType.StoredProcedure, "CountBooks", "CountBooks"), + ["GetBook"] = CreateEntity("get_book", EntitySourceType.StoredProcedure, "GetBook", "GetBook", + new EntityMcpOptions(customToolEnabled: true, dmlToolsEnabled: false)) + }; + + return CreateRuntimeConfig(entities); + } + + /// + /// Creates a runtime config with mixed entity types: + /// table (Book), view (BookView), and SP with dml-tools: false (GetBook). + /// Used to test that filtering applies to all entity types based on dml-tools setting. + /// + private static RuntimeConfig CreateConfigWithMixedEntityTypes() + { + Dictionary entities = new() + { + ["Book"] = CreateEntity("books", EntitySourceType.Table, "Book", "Books"), + ["BookView"] = CreateEntity("book_view", EntitySourceType.View, "BookView", "BookViews"), + ["GetBook"] = CreateEntity("get_book", EntitySourceType.StoredProcedure, "GetBook", "GetBook", + new EntityMcpOptions(customToolEnabled: true, dmlToolsEnabled: false)) + }; + + return CreateRuntimeConfig(entities); + } + + /// + /// Creates a runtime config with an empty entities dictionary. + /// Used to test the NoEntitiesConfigured error when no entities are configured at all. 
+ /// + private static RuntimeConfig CreateConfigWithNoEntities() + { + return CreateRuntimeConfig(new Dictionary()); + } + + /// + /// Creates a runtime config with a stored procedure that has BOTH custom-tool and dml-tools enabled. + /// Used to test the truth table scenario: custom-tool:true + dml-tools:true → should appear in describe_entities. + /// + private static RuntimeConfig CreateConfigWithCustomToolAndDmlEnabled() + { + Dictionary entities = new() + { + ["GetBook"] = CreateEntity("get_book", EntitySourceType.StoredProcedure, "GetBook", "GetBook", + new EntityMcpOptions(customToolEnabled: true, dmlToolsEnabled: true)) + }; + + return CreateRuntimeConfig(entities); + } + + /// + /// Creates a runtime config with an entity that has dml-tools disabled. + /// Used to test that entities with dml-tools: false are filtered from describe_entities. + /// + private static RuntimeConfig CreateConfigWithEntityDmlDisabled(EntitySourceType filteredEntityType, string includedEntityName, string filteredEntityName) + { + Dictionary entities = new(); + + // Add the included entity (different type based on what's being filtered) + if (filteredEntityType == EntitySourceType.Table) + { + entities[includedEntityName] = CreateEntity("publishers", EntitySourceType.Table, includedEntityName, $"{includedEntityName}s", + new EntityMcpOptions(customToolEnabled: null, dmlToolsEnabled: true)); + entities[filteredEntityName] = CreateEntity("books", EntitySourceType.Table, filteredEntityName, $"{filteredEntityName}s", + new EntityMcpOptions(customToolEnabled: null, dmlToolsEnabled: false)); + } + else if (filteredEntityType == EntitySourceType.View) + { + entities[includedEntityName] = CreateEntity("books", EntitySourceType.Table, includedEntityName, $"{includedEntityName}s"); + entities[filteredEntityName] = CreateEntity("book_view", EntitySourceType.View, filteredEntityName, $"{filteredEntityName}s", + new EntityMcpOptions(customToolEnabled: null, dmlToolsEnabled: false)); + } + + 
return CreateRuntimeConfig(entities); + } + + /// + /// Creates a runtime config where all entities have dml-tools disabled. + /// Used to test the AllEntitiesFilteredDmlDisabled error scenario. + /// + private static RuntimeConfig CreateConfigWithAllEntitiesDmlDisabled() + { + Dictionary entities = new() + { + ["Book"] = CreateEntity("books", EntitySourceType.Table, "Book", "Books", + new EntityMcpOptions(customToolEnabled: null, dmlToolsEnabled: false)), + ["BookView"] = CreateEntity("book_view", EntitySourceType.View, "BookView", "BookViews", + new EntityMcpOptions(customToolEnabled: null, dmlToolsEnabled: false)), + ["GetBook"] = CreateEntity("get_book", EntitySourceType.StoredProcedure, "GetBook", "GetBook", + new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: false)) + }; + + return CreateRuntimeConfig(entities); + } + + /// + /// Creates a service provider with mocked dependencies for testing DescribeEntitiesTool. + /// Configures anonymous role and necessary DAB services. 
+ /// + private static IServiceProvider CreateServiceProvider(RuntimeConfig config) + { + ServiceCollection services = new(); + + // Use shared test helper to create RuntimeConfigProvider + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + + // Mock IAuthorizationResolver + Mock mockAuthResolver = new(); + mockAuthResolver.Setup(x => x.IsValidRoleContext(It.IsAny())).Returns(true); + services.AddSingleton(mockAuthResolver.Object); + + // Mock HttpContext with anonymous role + Mock mockHttpContext = new(); + Mock mockRequest = new(); + mockRequest.Setup(x => x.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]).Returns("anonymous"); + mockHttpContext.Setup(x => x.Request).Returns(mockRequest.Object); + + Mock mockHttpContextAccessor = new(); + mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object); + services.AddSingleton(mockHttpContextAccessor.Object); + + // Add logging + services.AddLogging(); + + return services.BuildServiceProvider(); + } + + /// + /// Extracts and parses the JSON content from an MCP tool call result. + /// Returns the root JsonElement for assertion purposes. 
+ /// + private static JsonElement GetContentFromResult(CallToolResult result) + { + Assert.IsNotNull(result.Content); + Assert.IsTrue(result.Content.Count > 0); + + // Verify the content block is the expected type before casting + Assert.IsInstanceOfType(result.Content[0], typeof(TextContentBlock), + "Expected first content block to be TextContentBlock"); + + TextContentBlock firstContent = (TextContentBlock)result.Content[0]; + Assert.IsNotNull(firstContent.Text); + + return JsonDocument.Parse(firstContent.Text).RootElement; + } + + #endregion + } +} diff --git a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs index d2f6554cd3..278bc95cfb 100644 --- a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs +++ b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs @@ -7,9 +7,12 @@ using System.Threading; using System.Threading.Tasks; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Azure.DataApiBuilder.Mcp.Model; using Microsoft.AspNetCore.Http; @@ -188,8 +191,74 @@ public async Task DynamicCustomTool_RespectsCustomToolDisabled() AssertToolDisabledError(content, "Custom tool is disabled for entity 'GetBook'"); } + #region View Support Tests + + /// + /// Data-driven test to verify all DML tools allow both table and view entities. + /// This is critical for scenarios like vector data type support, where users must: + /// - Create a view that omits unsupported columns (e.g., vector columns) + /// - Perform DML operations against that view + /// + /// The tool type to test. + /// The entity source type (Table or View). 
+ /// The entity name to use. + /// The JSON arguments for the tool. + [DataTestMethod] + [DataRow("CreateRecord", "Table", "{\"entity\": \"Book\", \"data\": {\"id\": 1, \"title\": \"Test\"}}", DisplayName = "CreateRecord allows Table")] + [DataRow("CreateRecord", "View", "{\"entity\": \"BookView\", \"data\": {\"id\": 1, \"title\": \"Test\"}}", DisplayName = "CreateRecord allows View")] + [DataRow("ReadRecords", "Table", "{\"entity\": \"Book\"}", DisplayName = "ReadRecords allows Table")] + [DataRow("ReadRecords", "View", "{\"entity\": \"BookView\"}", DisplayName = "ReadRecords allows View")] + [DataRow("UpdateRecord", "Table", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", DisplayName = "UpdateRecord allows Table")] + [DataRow("UpdateRecord", "View", "{\"entity\": \"BookView\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", DisplayName = "UpdateRecord allows View")] + [DataRow("DeleteRecord", "Table", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}}", DisplayName = "DeleteRecord allows Table")] + [DataRow("DeleteRecord", "View", "{\"entity\": \"BookView\", \"keys\": {\"id\": 1}}", DisplayName = "DeleteRecord allows View")] + public async Task DmlTool_AllowsTablesAndViews(string toolType, string sourceType, string jsonArguments) + { + // Arrange + RuntimeConfig config = sourceType == "View" + ? 
CreateConfigWithViewEntity() + : CreateConfigWithDmlToolEnabledEntity(); + IServiceProvider serviceProvider = CreateServiceProvider(config); + IMcpTool tool = CreateTool(toolType); + + JsonDocument arguments = JsonDocument.Parse(jsonArguments); + + // Act + CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, CancellationToken.None); + + // Assert - Should NOT be a source type blocking error (InvalidEntity) + // Other errors like missing metadata are acceptable since we're testing source type validation + if (result.IsError == true) + { + JsonElement content = ParseResultContent(result); + + if (content.TryGetProperty("error", out JsonElement error) && + error.TryGetProperty("type", out JsonElement errorType)) + { + string errorTypeValue = errorType.GetString() ?? string.Empty; + + // This error type indicates the tool is blocking based on source type + Assert.AreNotEqual("InvalidEntity", errorTypeValue, + $"{sourceType} entities should not be blocked with InvalidEntity"); + } + } + } + + #endregion + #region Helper Methods + /// + /// Helper method to parse the JSON content from a CallToolResult without re-executing the tool. + /// + /// The result from executing an MCP tool. + /// The parsed JsonElement from the result's content. + private static JsonElement ParseResultContent(CallToolResult result) + { + TextContentBlock firstContent = (TextContentBlock)result.Content[0]; + return JsonDocument.Parse(firstContent.Text).RootElement; + } + /// /// Helper method to execute an MCP tool and return the parsed JsonElement from the result. 
/// @@ -200,8 +269,7 @@ public async Task DynamicCustomTool_RespectsCustomToolDisabled() private static async Task RunToolAsync(IMcpTool tool, JsonDocument arguments, IServiceProvider serviceProvider) { CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, CancellationToken.None); - TextContentBlock firstContent = (TextContentBlock)result.Content[0]; - return JsonDocument.Parse(firstContent.Text).RootElement; + return ParseResultContent(result); } /// @@ -517,8 +585,63 @@ private static RuntimeConfig CreateConfigWithRuntimeDisabledButEntityEnabled() ); } + /// + /// Creates a runtime config with a view entity. + /// This is the key scenario for vector data type support. + /// + private static RuntimeConfig CreateConfigWithViewEntity() + { + Dictionary entities = new() + { + ["BookView"] = new Entity( + Source: new EntitySource( + Object: "dbo.vBooks", + Type: EntitySourceType.View, + Parameters: null, + KeyFields: new[] { "id" } + ), + GraphQL: new("BookView", "BookViews"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null), + new EntityAction(Action: EntityActionOperation.Create, Fields: null, Policy: null), + new EntityAction(Action: EntityActionOperation.Update, Fields: null, Policy: null), + new EntityAction(Action: EntityActionOperation.Delete, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: true) + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + 
executeEntity: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + /// /// Creates a service provider with mocked dependencies for testing MCP tools. + /// Includes metadata provider mocks so tests can reach source type validation. /// private static IServiceProvider CreateServiceProvider(RuntimeConfig config) { @@ -540,6 +663,54 @@ private static IServiceProvider CreateServiceProvider(RuntimeConfig config) mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object); services.AddSingleton(mockHttpContextAccessor.Object); + // Add metadata provider mocks so tests can reach source type validation. + // This is required for DmlTool_AllowsTablesAndViews to actually test the source type behavior. + Mock mockSqlMetadataProvider = new(); + Dictionary entityToDatabaseObject = new(); + + // Add database objects for each entity in the config + if (config.Entities != null) + { + foreach (KeyValuePair kvp in config.Entities) + { + string entityName = kvp.Key; + Entity entity = kvp.Value; + EntitySourceType sourceType = entity.Source.Type ?? 
EntitySourceType.Table; + + DatabaseObject dbObject; + if (sourceType == EntitySourceType.View) + { + dbObject = new DatabaseView("dbo", entity.Source.Object) + { + SourceType = EntitySourceType.View + }; + } + else if (sourceType == EntitySourceType.StoredProcedure) + { + dbObject = new DatabaseStoredProcedure("dbo", entity.Source.Object) + { + SourceType = EntitySourceType.StoredProcedure + }; + } + else + { + dbObject = new DatabaseTable("dbo", entity.Source.Object) + { + SourceType = EntitySourceType.Table + }; + } + + entityToDatabaseObject[entityName] = dbObject; + } + } + + mockSqlMetadataProvider.Setup(x => x.EntityToDatabaseObject).Returns(entityToDatabaseObject); + mockSqlMetadataProvider.Setup(x => x.GetDatabaseType()).Returns(DatabaseType.MSSQL); + + Mock mockMetadataProviderFactory = new(); + mockMetadataProviderFactory.Setup(x => x.GetMetadataProvider(It.IsAny())).Returns(mockSqlMetadataProvider.Object); + services.AddSingleton(mockMetadataProviderFactory.Object); + services.AddLogging(); return services.BuildServiceProvider(); diff --git a/src/Service.Tests/Mcp/McpToolRegistryTests.cs b/src/Service.Tests/Mcp/McpToolRegistryTests.cs new file mode 100644 index 0000000000..7bbd91341c --- /dev/null +++ b/src/Service.Tests/Mcp/McpToolRegistryTests.cs @@ -0,0 +1,337 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Mcp.Core; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for McpToolRegistry to ensure tool name uniqueness validation. 
+ /// + [TestClass] + public class McpToolRegistryTests + { + /// + /// Test that registering multiple tools with unique names succeeds. + /// + [TestMethod] + public void RegisterTool_WithMultipleUniqueNames_Succeeds() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool1 = new MockMcpTool("tool_one", ToolType.BuiltIn); + IMcpTool tool2 = new MockMcpTool("tool_two", ToolType.Custom); + IMcpTool tool3 = new MockMcpTool("tool_three", ToolType.BuiltIn); + + // Act & Assert - should not throw + registry.RegisterTool(tool1); + registry.RegisterTool(tool2); + registry.RegisterTool(tool3); + + // Verify all tools were registered + IEnumerable allTools = registry.GetAllTools(); + Assert.AreEqual(3, allTools.Count()); + } + + /// + /// Test that registering duplicate tools of the same type throws an exception. + /// Validates that both built-in and custom tools enforce name uniqueness within their own type. + /// + [DataTestMethod] + [DataRow(ToolType.BuiltIn, "duplicate_tool", "built-in", DisplayName = "Duplicate Built-In Tools")] + [DataRow(ToolType.Custom, "my_custom_tool", "custom", DisplayName = "Duplicate Custom Tools")] + public void RegisterTool_WithDuplicateSameType_ThrowsException( + ToolType toolType, + string toolName, + string expectedToolTypeText) + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool1 = new MockMcpTool(toolName, toolType); + IMcpTool tool2 = new MockMcpTool(toolName, toolType); + + // Act - Register first tool + registry.RegisterTool(tool1); + + // Assert - Second registration should throw + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(tool2) + ); + + // Verify exception details + Assert.IsTrue(exception.Message.Contains($"Duplicate MCP tool name '{toolName}' detected")); + Assert.IsTrue(exception.Message.Contains($"{expectedToolTypeText} tool with this name is already registered")); + Assert.IsTrue(exception.Message.Contains($"Cannot register {expectedToolTypeText} 
tool with the same name")); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, exception.StatusCode); + } + + /// + /// Test that registering tools with conflicting names across different types throws an exception. + /// Validates that tool names must be unique across all tool types (built-in and custom). + /// + [DataTestMethod] + [DataRow("create_record", ToolType.BuiltIn, ToolType.Custom, "built-in", "custom", DisplayName = "Built-In then Custom conflict")] + [DataRow("read_records", ToolType.BuiltIn, ToolType.Custom, "built-in", "custom", DisplayName = "Built-In then Custom conflict (read_records)")] + [DataRow("my_stored_proc", ToolType.Custom, ToolType.BuiltIn, "custom", "built-in", DisplayName = "Custom then Built-In conflict")] + public void RegisterTool_WithCrossTypeConflict_ThrowsException( + string toolName, + ToolType firstToolType, + ToolType secondToolType, + string expectedExistingType, + string expectedNewType) + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool existingTool = new MockMcpTool(toolName, firstToolType); + IMcpTool conflictingTool = new MockMcpTool(toolName, secondToolType); + + // Act - Register first tool + registry.RegisterTool(existingTool); + + // Assert - Second tool registration should throw + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(conflictingTool) + ); + + // Verify exception details + Assert.IsTrue(exception.Message.Contains($"Duplicate MCP tool name '{toolName}' detected")); + Assert.IsTrue(exception.Message.Contains($"{expectedExistingType} tool with this name is already registered")); + Assert.IsTrue(exception.Message.Contains($"Cannot register {expectedNewType} tool with the same name")); + Assert.IsTrue(exception.Message.Contains("Tool names must be unique across all tool types")); + 
 Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, exception.StatusCode); + } + + /// + /// Test that tool name comparison is case-insensitive. + /// Tools with different casing should not be allowed. + /// + [TestMethod] + public void RegisterTool_WithDifferentCasing_ThrowsException() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool1 = new MockMcpTool("my_tool", ToolType.BuiltIn); + IMcpTool tool2 = new MockMcpTool("My_Tool", ToolType.Custom); + + // Act - Register first tool + registry.RegisterTool(tool1); + + // Assert - Case-insensitive duplicate should throw + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(tool2) + ); + + Assert.IsTrue(exception.Message.Contains("Duplicate MCP tool name")); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); + } + + /// + /// Test that registering the same tool instance twice is silently ignored (idempotent). + /// This supports stdio mode where both McpToolRegistryInitializer and McpStdioHelper may register the same tools. + /// + [TestMethod] + public void RegisterTool_SameInstanceTwice_IsIdempotent() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool = new MockMcpTool("my_tool", ToolType.BuiltIn); + + // Act - Register the same instance twice + registry.RegisterTool(tool); + registry.RegisterTool(tool); + + // Assert - Tool should be registered only once + IEnumerable allTools = registry.GetAllTools(); + Assert.AreEqual(1, allTools.Count()); + } + + /// + /// Test that registering a different instance with the same name throws an exception, + /// even though a same-instance re-registration would be allowed. 
+ /// + [TestMethod] + public void RegisterTool_DifferentInstanceSameName_ThrowsException() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool1 = new MockMcpTool("my_tool", ToolType.BuiltIn); + IMcpTool tool2 = new MockMcpTool("my_tool", ToolType.BuiltIn); + + // Act - Register first instance + registry.RegisterTool(tool1); + + // Assert - Different instance with same name should throw + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(tool2) + ); + + Assert.IsTrue(exception.Message.Contains("Duplicate MCP tool name 'my_tool' detected")); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); + } + + /// + /// Test that GetAllTools returns all registered tools. + /// + [TestMethod] + public void GetAllTools_ReturnsAllRegisteredTools() + { + // Arrange + McpToolRegistry registry = new(); + registry.RegisterTool(new MockMcpTool("tool_a", ToolType.BuiltIn)); + registry.RegisterTool(new MockMcpTool("tool_b", ToolType.Custom)); + registry.RegisterTool(new MockMcpTool("tool_c", ToolType.BuiltIn)); + + // Act + IEnumerable allTools = registry.GetAllTools(); + + // Assert + Assert.AreEqual(3, allTools.Count()); + Assert.IsTrue(allTools.Any(t => t.Name == "tool_a")); + Assert.IsTrue(allTools.Any(t => t.Name == "tool_b")); + Assert.IsTrue(allTools.Any(t => t.Name == "tool_c")); + } + + /// + /// Test that TryGetTool returns false for non-existent tool. + /// + [TestMethod] + public void TryGetTool_WithNonExistentName_ReturnsFalse() + { + // Arrange + McpToolRegistry registry = new(); + registry.RegisterTool(new MockMcpTool("existing_tool", ToolType.BuiltIn)); + + // Act + bool found = registry.TryGetTool("non_existent_tool", out IMcpTool? tool); + + // Assert + Assert.IsFalse(found); + Assert.IsNull(tool); + } + + /// + /// Test edge case: empty tool name should throw exception. 
+ /// + [TestMethod] + public void RegisterTool_WithEmptyToolName_ThrowsException() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool = new MockMcpTool("", ToolType.BuiltIn); + + // Assert - Empty tool names should be rejected + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(tool) + ); + + Assert.IsTrue(exception.Message.Contains("cannot be null, empty, or whitespace")); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ErrorInInitialization, exception.SubStatusCode); + } + + /// + /// Test realistic scenario with actual built-in tool names. + /// + [TestMethod] + public void RegisterTool_WithRealisticBuiltInToolNames_DetectsDuplicates() + { + // Arrange + McpToolRegistry registry = new(); + + // Simulate registering built-in tools + registry.RegisterTool(new MockMcpTool("create_record", ToolType.BuiltIn)); + registry.RegisterTool(new MockMcpTool("read_records", ToolType.BuiltIn)); + registry.RegisterTool(new MockMcpTool("update_record", ToolType.BuiltIn)); + registry.RegisterTool(new MockMcpTool("delete_record", ToolType.BuiltIn)); + registry.RegisterTool(new MockMcpTool("describe_entities", ToolType.BuiltIn)); + + // Try to register a custom tool with a conflicting name + IMcpTool customTool = new MockMcpTool("read_records", ToolType.Custom); + + // Assert - Should throw + DataApiBuilderException exception = Assert.ThrowsException( + () => registry.RegisterTool(customTool) + ); + + Assert.IsTrue(exception.Message.Contains("read_records")); + Assert.IsTrue(exception.Message.Contains("built-in tool")); + } + + /// + /// Test that registering a tool with leading/trailing whitespace in the name is treated as a duplicate of the trimmed name. + /// Note: during tool registration, the registry should trim whitespace and detect duplicates accordingly. 
+ /// + [TestMethod] + public void RegisterTool_WithLeadingTrailingWhitespace_DetectsDuplicate() + { + // Arrange + McpToolRegistry registry = new(); + IMcpTool tool1 = new MockMcpTool("my_tool", ToolType.BuiltIn); + IMcpTool tool2 = new MockMcpTool(" my_tool ", ToolType.Custom); + + // Act + registry.RegisterTool(tool1); + + // Assert - trimmed name should collide + Assert.ThrowsException( + () => registry.RegisterTool(tool2) + ); + } + + #region Private helpers + + /// + /// Mock implementation of IMcpTool for testing purposes. + /// + private class MockMcpTool : IMcpTool + { + private readonly string _toolName; + + public MockMcpTool(string toolName, ToolType toolType) + { + _toolName = toolName; + ToolType = toolType; + } + + public ToolType ToolType { get; } + + public Tool GetToolMetadata() + { + // Create a simple JSON object for the input schema + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = _toolName, + Description = $"Mock {ToolType} tool", + InputSchema = doc.RootElement.Clone() + }; + } + + public Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + // Not used in these tests + throw new NotImplementedException(); + } + } + + #endregion Private helpers + } +} diff --git a/src/Service.Tests/ModuleInitializer.cs b/src/Service.Tests/ModuleInitializer.cs index 89b7dbc3c4..b2d713b804 100644 --- a/src/Service.Tests/ModuleInitializer.cs +++ b/src/Service.Tests/ModuleInitializer.cs @@ -27,6 +27,8 @@ public static void Init() VerifierSettings.IgnoreMember(dataSource => dataSource.IsDatasourceHealthEnabled); // Ignore the DatasourceThresholdMs from the output to avoid committing it. VerifierSettings.IgnoreMember(dataSource => dataSource.DatasourceThresholdMs); + // Ignore the IsUserDelegatedAuthEnabled from the output as it's a computed property. 
+ VerifierSettings.IgnoreMember(dataSource => dataSource.IsUserDelegatedAuthEnabled); // Ignore the JSON schema path as that's unimportant from a test standpoint. VerifierSettings.IgnoreMember(config => config.Schema); // Ignore the datasource files as that's unimportant from a test standpoint. VerifierSettings.IgnoreMember(config => config.DatasourceFiles); @@ -103,6 +105,8 @@ public static void Init() VerifierSettings.IgnoreMember(options => options.EnableDwNto1JoinOpt); // Ignore the FeatureFlags as that's unimportant from a test standpoint. VerifierSettings.IgnoreMember(options => options.FeatureFlags); + // Ignore the Autoentities as that's unimportant from a test standpoint. + VerifierSettings.IgnoreMember(config => config.Autoentities); // Ignore the message as that's not serialized in our config file anyway. VerifierSettings.IgnoreMember(dataSource => dataSource.DatabaseTypeNotSupportedMessage); // Ignore DefaultDataSourceName as that's not serialized in our config file. diff --git a/src/Service.Tests/Multidab-config.MsSql.json b/src/Service.Tests/Multidab-config.MsSql.json index b54b629023..68a8f0df6d 100644 --- a/src/Service.Tests/Multidab-config.MsSql.json +++ b/src/Service.Tests/Multidab-config.MsSql.json @@ -2,7 +2,7 @@ "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json", "data-source": { "database-type": "mssql", - "connection-string": "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=5;", + "connection-string": "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=30;", "options": { "set-session-context": true } @@ -1563,5 +1563,81 @@ } ] } + }, + "autoentities": { + "AutoPublisher": { + "patterns": { + "include": [ + "%publisher%" + ], + "exclude": [ + "%book%" + ], + "name": "auto{object}" + }, + "template": { + "rest": { + "enabled": true + }, + "graphql": { + "enabled": true + }, 
"health": { + "enabled": false + }, + "cache": { + "enabled": false, + "ttl-seconds": 10, + "level": "l1l2" + } + }, + "permissions": [ + { + "role": "anonymous", + "actions": [ + { + "action": "*" + } + ] + } + ] + }, + "NewBooks": { + "patterns": { + "include": [ + "%book%" + ], + "exclude": [ + "%publisher%" + ], + "name": "{schema}_auto_{object}" + }, + "template": { + "rest": { + "enabled": true + }, + "graphql": { + "enabled": true + }, + "health": { + "enabled": true + }, + "cache": { + "enabled": true, + "ttl-seconds": 5, + "level": "l1l2" + } + }, + "permissions": [ + { + "role": "anonymous", + "actions": [ + { + "action": "read" + } + ] + } + ] + } } } diff --git a/src/Service.Tests/OpenApiDocumentor/DocumentVerbosityTests.cs b/src/Service.Tests/OpenApiDocumentor/DocumentVerbosityTests.cs index fa43617f4f..c21718c87b 100644 --- a/src/Service.Tests/OpenApiDocumentor/DocumentVerbosityTests.cs +++ b/src/Service.Tests/OpenApiDocumentor/DocumentVerbosityTests.cs @@ -23,7 +23,7 @@ public class DocumentVerbosityTests private const string UNEXPECTED_CONTENTS_ERROR = "Unexpected number of response objects to validate."; /// - /// Validates that for the Book entity, 7 response object schemas generated by OpenApiDocumentor + /// Validates that for the Book entity, 9 response object schemas generated by OpenApiDocumentor /// contain a 'type' property with value 'object'. 
/// /// Two paths: @@ -32,9 +32,9 @@ public class DocumentVerbosityTests /// - Validate responses that return result contents: /// GET (200), PUT (200, 201), PATCH (200, 201) /// - "/Books" - /// - 2 operations GET(all) POST + /// - 4 operations GET(all) POST PUT(keyless) PATCH(keyless) /// - Validate responses that return result contents: - /// GET (200), POST (201) + /// GET (200), POST (201), PUT keyless (201), PATCH keyless (201) /// [TestMethod] public async Task ResponseObjectSchemaIncludesTypeProperty() @@ -71,10 +71,10 @@ public async Task ResponseObjectSchemaIncludesTypeProperty() .Select(pair => pair.Value) .ToList(); - // Validate that 7 response object schemas contain a 'type' property with value 'object' - // Test summary describes all 7 expected responses. + // Validate that 9 response object schemas contain a 'type' property with value 'object' + // Test summary describes all 9 expected responses. Assert.IsTrue( - condition: responses.Count == 7, + condition: responses.Count == 9, message: UNEXPECTED_CONTENTS_ERROR); foreach (OpenApiResponse response in responses) diff --git a/src/Service.Tests/OpenApiDocumentor/FieldFilteringTests.cs b/src/Service.Tests/OpenApiDocumentor/FieldFilteringTests.cs new file mode 100644 index 0000000000..9e2aca4b9d --- /dev/null +++ b/src/Service.Tests/OpenApiDocumentor/FieldFilteringTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Microsoft.OpenApi.Models; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.OpenApiIntegration +{ + /// + /// Tests validating OpenAPI schema filters fields based on entity permissions. 
+ /// + [TestCategory(TestCategory.MSSQL)] + [TestClass] + public class FieldFilteringTests + { + private const string CONFIG_FILE = "field-filter-config.MsSql.json"; + private const string DB_ENV = TestCategory.MSSQL; + + /// + /// Validates that excluded fields are not shown in OpenAPI schema. + /// + [TestMethod] + public async Task ExcludedFields_NotShownInSchema() + { + // Create permission with excluded field + EntityActionFields fields = new(Exclude: new HashSet { "publisher_id" }, Include: null); + EntityPermission[] permissions = new[] + { + new EntityPermission(Role: "anonymous", Actions: new[] { new EntityAction(EntityActionOperation.All, fields, new()) }) + }; + + OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions); + + // Check that the excluded field is not in the schema + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book"), "Schema should exist for book entity"); + Assert.IsFalse(doc.Components.Schemas["book"].Properties.ContainsKey("publisher_id"), "Excluded field should not be in schema"); + } + + /// + /// Validates superset of fields across different role permissions is shown. 
+ /// + [TestMethod] + public async Task MixedRoleFieldPermissions_ShowsSupersetOfFields() + { + // Anonymous can see id only, authenticated can see title only + EntityActionFields anonymousFields = new(Exclude: new HashSet(), Include: new HashSet { "id" }); + EntityActionFields authenticatedFields = new(Exclude: new HashSet(), Include: new HashSet { "title" }); + EntityPermission[] permissions = new[] + { + new EntityPermission(Role: "anonymous", Actions: new[] { new EntityAction(EntityActionOperation.Read, anonymousFields, new()) }), + new EntityPermission(Role: "authenticated", Actions: new[] { new EntityAction(EntityActionOperation.Read, authenticatedFields, new()) }) + }; + + OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions); + + // Should have both id (from anonymous) and title (from authenticated) - superset of fields + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book"), "Schema should exist for book entity"); + Assert.IsTrue(doc.Components.Schemas["book"].Properties.ContainsKey("id"), "Field 'id' should be in schema from anonymous role"); + Assert.IsTrue(doc.Components.Schemas["book"].Properties.ContainsKey("title"), "Field 'title' should be in schema from authenticated role"); + } + + private static async Task GenerateDocumentWithPermissions(EntityPermission[] permissions) + { + Entity entity = new( + Source: new("books", EntitySourceType.Table, null, null), + Fields: null, + GraphQL: new(null, null, false), + Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS), + Permissions: permissions, + Mappings: null, + Relationships: null); + + RuntimeEntities entities = new(new Dictionary { { "book", entity } }); + return await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(entities, CONFIG_FILE, DB_ENV); + } + } +} diff --git a/src/Service.Tests/OpenApiDocumentor/OpenApiTestBootstrap.cs b/src/Service.Tests/OpenApiDocumentor/OpenApiTestBootstrap.cs index a2440f728e..82b9390ab7 100644 --- 
a/src/Service.Tests/OpenApiDocumentor/OpenApiTestBootstrap.cs +++ b/src/Service.Tests/OpenApiDocumentor/OpenApiTestBootstrap.cs @@ -27,22 +27,32 @@ internal class OpenApiTestBootstrap /// /// /// + /// Optional value for request-body-strict setting. If null, uses default (true). + /// Optional role to filter OpenAPI document. If null, returns superset of all roles. /// Generated OpenApiDocument internal static async Task GenerateOpenApiDocumentAsync( RuntimeEntities runtimeEntities, string configFileName, - string databaseEnvironment) + string databaseEnvironment, + bool? requestBodyStrict = null, + string role = null) { TestHelper.SetupDatabaseEnvironment(databaseEnvironment); FileSystem fileSystem = new(); FileSystemRuntimeConfigLoader loader = new(fileSystem); loader.TryLoadKnownConfig(out RuntimeConfig config); + // Create Rest options with the specified request-body-strict setting + RestRuntimeOptions restOptions = requestBodyStrict.HasValue + ? config.Runtime?.Rest with { RequestBodyStrict = requestBodyStrict.Value } ?? new RestRuntimeOptions(RequestBodyStrict: requestBodyStrict.Value) + : config.Runtime?.Rest ?? new RestRuntimeOptions(); + RuntimeConfig configWithCustomHostMode = config with { Runtime = config.Runtime with { - Host = config.Runtime?.Host with { Mode = HostMode.Production } + Host = config.Runtime?.Host with { Mode = HostMode.Development }, + Rest = restOptions }, Entities = runtimeEntities }; @@ -56,7 +66,8 @@ internal static async Task GenerateOpenApiDocumentAsync( using TestServer server = new(Program.CreateWebHostBuilder(args)); using HttpClient client = server.CreateClient(); { - HttpRequestMessage request = new(HttpMethod.Get, "/api/openapi"); + string requestUrl = role is null ? 
"/api/openapi" : $"/api/openapi/{role}"; + HttpRequestMessage request = new(HttpMethod.Get, requestUrl); HttpResponseMessage response = await client.SendAsync(request); Stream responseStream = await response.Content.ReadAsStreamAsync(); diff --git a/src/Service.Tests/OpenApiDocumentor/OperationFilteringTests.cs b/src/Service.Tests/OpenApiDocumentor/OperationFilteringTests.cs new file mode 100644 index 0000000000..ef2c4fc5e8 --- /dev/null +++ b/src/Service.Tests/OpenApiDocumentor/OperationFilteringTests.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Microsoft.OpenApi.Models; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.OpenApiIntegration +{ + /// + /// Tests validating OpenAPI document filters REST methods based on entity permissions. + /// + [TestCategory(TestCategory.MSSQL)] + [TestClass] + public class OperationFilteringTests + { + private const string CONFIG_FILE = "operation-filter-config.MsSql.json"; + private const string DB_ENV = TestCategory.MSSQL; + + /// + /// Validates read-only entity shows only GET operations. 
+ /// + [TestMethod] + public async Task ReadOnlyEntity_ShowsOnlyGetOperations() + { + EntityPermission[] permissions = new[] + { + new EntityPermission(Role: "anonymous", Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) }) + }; + + OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions); + + foreach (var path in doc.Paths) + { + Assert.IsTrue(path.Value.Operations.ContainsKey(OperationType.Get), $"GET missing at {path.Key}"); + Assert.IsFalse(path.Value.Operations.ContainsKey(OperationType.Post), $"POST should not exist at {path.Key}"); + Assert.IsFalse(path.Value.Operations.ContainsKey(OperationType.Put), $"PUT should not exist at {path.Key}"); + Assert.IsFalse(path.Value.Operations.ContainsKey(OperationType.Patch), $"PATCH should not exist at {path.Key}"); + Assert.IsFalse(path.Value.Operations.ContainsKey(OperationType.Delete), $"DELETE should not exist at {path.Key}"); + } + } + + /// + /// Validates wildcard (*) permission shows all CRUD operations. + /// + [TestMethod] + public async Task WildcardPermission_ShowsAllOperations() + { + OpenApiDocument doc = await GenerateDocumentWithPermissions(OpenApiTestBootstrap.CreateBasicPermissions()); + + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get))); + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post))); + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put))); + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch))); + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Delete))); + } + + /// + /// Validates entity with no permissions is omitted from OpenAPI document. 
+ /// + [TestMethod] + public async Task EntityWithNoPermissions_IsOmittedFromDocument() + { + // Entity with no permissions + Entity entityNoPerms = new( + Source: new("books", EntitySourceType.Table, null, null), + Fields: null, + GraphQL: new(null, null, false), + Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS), + Permissions: [], + Mappings: null, + Relationships: null); + + // Entity with permissions for reference + Entity entityWithPerms = new( + Source: new("publishers", EntitySourceType.Table, null, null), + Fields: null, + GraphQL: new(null, null, false), + Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS), + Permissions: OpenApiTestBootstrap.CreateBasicPermissions(), + Mappings: null, + Relationships: null); + + RuntimeEntities entities = new(new Dictionary + { + { "book", entityNoPerms }, + { "publisher", entityWithPerms } + }); + + OpenApiDocument doc = await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(entities, CONFIG_FILE, DB_ENV); + + Assert.IsFalse(doc.Paths.Keys.Any(k => k.Contains("book")), "Entity with no permissions should not have paths"); + Assert.IsFalse(doc.Tags.Any(t => t.Name == "book"), "Entity with no permissions should not have tag"); + Assert.IsTrue(doc.Paths.Keys.Any(k => k.Contains("publisher")), "Entity with permissions should have paths"); + } + + /// + /// Validates superset of permissions across roles is shown. 
+ /// + [TestMethod] + public async Task MixedRolePermissions_ShowsSupersetOfOperations() + { + EntityPermission[] permissions = new[] + { + new EntityPermission(Role: "anonymous", Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) }), + new EntityPermission(Role: "authenticated", Actions: new[] { new EntityAction(EntityActionOperation.Create, null, new()) }) + }; + + OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions); + + // Should have both GET (from anonymous read) and POST (from authenticated create) + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)), "GET should exist from anonymous read"); + Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)), "POST should exist from authenticated create"); + } + + private static async Task GenerateDocumentWithPermissions(EntityPermission[] permissions) + { + Entity entity = new( + Source: new("books", EntitySourceType.Table, null, null), + Fields: null, + GraphQL: new(null, null, false), + Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS), + Permissions: permissions, + Mappings: null, + Relationships: null); + + RuntimeEntities entities = new(new Dictionary { { "book", entity } }); + return await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(entities, CONFIG_FILE, DB_ENV); + } + } +} diff --git a/src/Service.Tests/OpenApiDocumentor/ParameterValidationTests.cs b/src/Service.Tests/OpenApiDocumentor/ParameterValidationTests.cs index 7c0e0225ae..5caed5a6b9 100644 --- a/src/Service.Tests/OpenApiDocumentor/ParameterValidationTests.cs +++ b/src/Service.Tests/OpenApiDocumentor/ParameterValidationTests.cs @@ -117,8 +117,13 @@ public async Task TestQueryParametersExcludedFromNonReadOperationsOnTablesAndVie OpenApiPathItem pathWithouId = openApiDocument.Paths[$"/{entityName}"]; Assert.IsTrue(pathWithouId.Operations.ContainsKey(OperationType.Post)); 
Assert.IsFalse(pathWithouId.Operations[OperationType.Post].Parameters.Any(param => param.In is ParameterLocation.Query)); - Assert.IsFalse(pathWithouId.Operations.ContainsKey(OperationType.Put)); - Assert.IsFalse(pathWithouId.Operations.ContainsKey(OperationType.Patch)); + + // With keyless PUT/PATCH support, PUT and PATCH operations are present on the base path + // for entities with auto-generated primary keys. Validate they don't have query parameters. + Assert.IsTrue(pathWithouId.Operations.ContainsKey(OperationType.Put)); + Assert.IsFalse(pathWithouId.Operations[OperationType.Put].Parameters.Any(param => param.In is ParameterLocation.Query)); + Assert.IsTrue(pathWithouId.Operations.ContainsKey(OperationType.Patch)); + Assert.IsFalse(pathWithouId.Operations[OperationType.Patch].Parameters.Any(param => param.In is ParameterLocation.Query)); Assert.IsFalse(pathWithouId.Operations.ContainsKey(OperationType.Delete)); // Assert that Query Parameters Excluded From NonReadOperations for path with id. 
diff --git a/src/Service.Tests/OpenApiDocumentor/PathValidationTests.cs b/src/Service.Tests/OpenApiDocumentor/PathValidationTests.cs index 5f478b3b80..7ba138de52 100644 --- a/src/Service.Tests/OpenApiDocumentor/PathValidationTests.cs +++ b/src/Service.Tests/OpenApiDocumentor/PathValidationTests.cs @@ -38,7 +38,6 @@ public class PathValidationTests [DataRow("entity", "//customEntityPath", "/customEntityPath", DisplayName = "Entity REST path has two leading slashes - REST path override used.")] [DataRow("entity", "///customEntityPath", "/customEntityPath", DisplayName = "Entity REST path has many leading slashes - REST path override used.")] [DataRow("entity", "customEntityPath", "/customEntityPath", DisplayName = "Entity REST path has no leading slash(es) - REST path override used.")] - [DataRow("entity", "", "/entity", DisplayName = "Entity REST path is an emtpy string - top level entity name used.")] [DataRow("entity", null, "/entity", DisplayName = "Entity REST path is null - top level entity name used.")] [DataTestMethod] public async Task ValidateEntityRestPath(string entityName, string configuredRestPath, string expectedOpenApiPath) diff --git a/src/Service.Tests/OpenApiDocumentor/RequestBodyStrictTests.cs b/src/Service.Tests/OpenApiDocumentor/RequestBodyStrictTests.cs new file mode 100644 index 0000000000..ccbe50ddc6 --- /dev/null +++ b/src/Service.Tests/OpenApiDocumentor/RequestBodyStrictTests.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Microsoft.OpenApi.Models; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.OpenApiIntegration +{ + /// + /// Tests validating OpenAPI schema correctly applies request-body-strict setting. 
+ /// + [TestCategory(TestCategory.MSSQL)] + [TestClass] + public class RequestBodyStrictTests + { + private const string CONFIG_FILE = "request-body-strict-config.MsSql.json"; + private const string DB_ENV = TestCategory.MSSQL; + + /// + /// Validates that when request-body-strict is true (default), request body schemas + /// have additionalProperties set to false. + /// + [TestMethod] + public async Task RequestBodyStrict_True_DisallowsExtraFields() + { + OpenApiDocument doc = await GenerateDocumentWithPermissions( + OpenApiTestBootstrap.CreateBasicPermissions(), + requestBodyStrict: true); + + // Request body schemas should have additionalProperties = false + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book_NoAutoPK"), "POST request body schema should exist"); + Assert.IsFalse(doc.Components.Schemas["book_NoAutoPK"].AdditionalPropertiesAllowed, "POST request body should not allow extra fields in strict mode"); + + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book_NoPK"), "PUT/PATCH request body schema should exist"); + Assert.IsFalse(doc.Components.Schemas["book_NoPK"].AdditionalPropertiesAllowed, "PUT/PATCH request body should not allow extra fields in strict mode"); + + // Response body schema should allow extra fields (not a request body) + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book"), "Response body schema should exist"); + Assert.IsTrue(doc.Components.Schemas["book"].AdditionalPropertiesAllowed, "Response body should allow extra fields"); + } + + /// + /// Validates that when request-body-strict is false, request body schemas + /// have additionalProperties set to true. 
+ /// + [TestMethod] + public async Task RequestBodyStrict_False_AllowsExtraFields() + { + OpenApiDocument doc = await GenerateDocumentWithPermissions( + OpenApiTestBootstrap.CreateBasicPermissions(), + requestBodyStrict: false); + + // Request body schemas should have additionalProperties = true + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book_NoAutoPK"), "POST request body schema should exist"); + Assert.IsTrue(doc.Components.Schemas["book_NoAutoPK"].AdditionalPropertiesAllowed, "POST request body should allow extra fields in non-strict mode"); + + Assert.IsTrue(doc.Components.Schemas.ContainsKey("book_NoPK"), "PUT/PATCH request body schema should exist"); + Assert.IsTrue(doc.Components.Schemas["book_NoPK"].AdditionalPropertiesAllowed, "PUT/PATCH request body should allow extra fields in non-strict mode"); + } + + private static async Task GenerateDocumentWithPermissions(EntityPermission[] permissions, bool? requestBodyStrict = null) + { + Entity entity = new( + Source: new("books", EntitySourceType.Table, null, null), + Fields: null, + GraphQL: new(null, null, false), + Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS), + Permissions: permissions, + Mappings: null, + Relationships: null); + + RuntimeEntities entities = new(new Dictionary { { "book", entity } }); + return await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(entities, CONFIG_FILE, DB_ENV, requestBodyStrict); + } + } +} diff --git a/src/Service.Tests/OpenApiDocumentor/RoleIsolationTests.cs b/src/Service.Tests/OpenApiDocumentor/RoleIsolationTests.cs new file mode 100644 index 0000000000..31f4e19f3f --- /dev/null +++ b/src/Service.Tests/OpenApiDocumentor/RoleIsolationTests.cs @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Azure.DataApiBuilder.Config.ObjectModel;
using Microsoft.OpenApi.Models;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Azure.DataApiBuilder.Service.Tests.OpenApiIntegration
{
    /// <summary>
    /// Tests validating OpenAPI document correctly isolates permissions between roles.
    /// </summary>
    [TestCategory(TestCategory.MSSQL)]
    [TestClass]
    public class RoleIsolationTests
    {
        // Config file and database environment used to bootstrap document generation.
        private const string CONFIG_FILE = "role-isolation-config.MsSql.json";
        private const string DB_ENV = TestCategory.MSSQL;

        /// <summary>
        /// Validates that anonymous role is distinct from superset (no role specified).
        /// When two roles have different permissions, the superset should contain both,
        /// but the anonymous-specific view should only contain anonymous permissions.
        /// </summary>
        [TestMethod]
        public async Task AnonymousRole_IsDistinctFromSuperset()
        {
            // Anonymous can only read, authenticated can create/update/delete.
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(
                    Role: "anonymous",
                    Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) }),
                new EntityPermission(
                    Role: "authenticated",
                    Actions: new[]
                    {
                        new EntityAction(EntityActionOperation.Create, null, new()),
                        new EntityAction(EntityActionOperation.Update, null, new()),
                        new EntityAction(EntityActionOperation.Delete, null, new())
                    })
            };

            // Superset (no role) should have all operations.
            OpenApiDocument supersetDoc = await GenerateDocumentWithPermissions(permissions);
            Assert.IsTrue(supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)), "Superset should have GET");
            Assert.IsTrue(supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)), "Superset should have POST");
            Assert.IsTrue(supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)), "Superset should have PUT");
            Assert.IsTrue(supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)), "Superset should have PATCH");
            Assert.IsTrue(supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Delete)), "Superset should have DELETE");
        }

        /// <summary>
        /// Validates competing roles don't leak operations to each other.
        /// When one role has read-only and another has write-only, each role's
        /// OpenAPI should only show their specific permissions.
        /// </summary>
        [TestMethod]
        public async Task CompetingRoles_DoNotLeakOperations()
        {
            // Role "reader" can only read, role "writer" can only create.
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(Role: "reader", Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) }),
                new EntityPermission(Role: "writer", Actions: new[] { new EntityAction(EntityActionOperation.Create, null, new()) })
            };

            // The superset should have both GET and POST.
            OpenApiDocument supersetDoc = await GenerateDocumentWithPermissions(permissions);
            Assert.IsTrue(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)),
                "Superset should have GET from reader");
            Assert.IsTrue(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)),
                "Superset should have POST from writer");

            // Neither role alone should have all operations - they don't leak.
            // This test confirms the superset correctly combines permissions while
            // the individual role filtering (when implemented for direct calls) would not.
            Assert.IsFalse(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)),
                "No role has PUT, superset should not have it");
            Assert.IsFalse(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)),
                "No role has PATCH, superset should not have it");
            Assert.IsFalse(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Delete)),
                "No role has DELETE, superset should not have it");
        }

        /// <summary>
        /// Validates that PUT/PATCH require both Create and Update permissions.
        /// Since PUT/PATCH can create if missing (upsert), both permissions are needed at runtime.
        /// </summary>
        [TestMethod]
        public async Task PutPatchOperations_RequireBothCreateAndUpdatePermissions()
        {
            // Test 1: Only Create permission - should NOT have PUT/PATCH.
            EntityPermission[] createOnly = new[]
            {
                new EntityPermission(Role: "creator", Actions: new[] { new EntityAction(EntityActionOperation.Create, null, new()) })
            };
            OpenApiDocument docCreateOnly = await GenerateDocumentWithPermissions(createOnly);
            Assert.IsTrue(docCreateOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)), "Should have POST with Create permission");
            Assert.IsFalse(docCreateOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)), "Should NOT have PUT with only Create permission");
            Assert.IsFalse(docCreateOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)), "Should NOT have PATCH with only Create permission");

            // Test 2: Only Update permission - should NOT have PUT/PATCH.
            EntityPermission[] updateOnly = new[]
            {
                new EntityPermission(Role: "updater", Actions: new[] { new EntityAction(EntityActionOperation.Update, null, new()) })
            };
            OpenApiDocument docUpdateOnly = await GenerateDocumentWithPermissions(updateOnly);
            Assert.IsFalse(docUpdateOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)), "Should NOT have PUT with only Update permission");
            Assert.IsFalse(docUpdateOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)), "Should NOT have PATCH with only Update permission");

            // Test 3: Both Create and Update permissions - should have PUT/PATCH.
            EntityPermission[] createAndUpdate = new[]
            {
                new EntityPermission(
                    Role: "editor",
                    Actions: new[]
                    {
                        new EntityAction(EntityActionOperation.Create, null, new()),
                        new EntityAction(EntityActionOperation.Update, null, new())
                    })
            };
            OpenApiDocument docBoth = await GenerateDocumentWithPermissions(createAndUpdate);
            Assert.IsTrue(docBoth.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)), "Should have POST with Create permission");
            Assert.IsTrue(docBoth.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)), "Should have PUT with both Create and Update permissions");
            Assert.IsTrue(docBoth.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)), "Should have PATCH with both Create and Update permissions");
        }

        /// <summary>
        /// Validates competing roles don't leak fields to each other.
        /// When one role has access to field A and another has access to field B,
        /// the superset should have both, but individual role filtering should not leak.
        /// </summary>
        [TestMethod]
        public async Task CompetingRoles_DoNotLeakFields()
        {
            // Reader can see 'id', writer can see 'title'.
            EntityActionFields readerFields = new(Exclude: new HashSet<string>(), Include: new HashSet<string> { "id" });
            EntityActionFields writerFields = new(Exclude: new HashSet<string>(), Include: new HashSet<string> { "title" });
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(Role: "reader", Actions: new[] { new EntityAction(EntityActionOperation.Read, readerFields, new()) }),
                new EntityPermission(Role: "writer", Actions: new[] { new EntityAction(EntityActionOperation.Create, writerFields, new()) })
            };

            // The superset should have both fields.
            OpenApiDocument supersetDoc = await GenerateDocumentWithPermissions(permissions);
            Assert.IsTrue(supersetDoc.Components.Schemas.ContainsKey("book"), "Schema should exist");
            Assert.IsTrue(supersetDoc.Components.Schemas["book"].Properties.ContainsKey("id"), "Superset should have 'id' from reader");
            Assert.IsTrue(supersetDoc.Components.Schemas["book"].Properties.ContainsKey("title"), "Superset should have 'title' from writer");

            // Reader role should only see 'id', not 'title'.
            OpenApiDocument readerDoc = await GenerateDocumentWithPermissions(permissions, role: "reader");
            Assert.IsTrue(readerDoc.Components.Schemas.ContainsKey("book"), "Reader schema should exist");
            Assert.IsTrue(readerDoc.Components.Schemas["book"].Properties.ContainsKey("id"), "Reader should see 'id'");
            Assert.IsFalse(readerDoc.Components.Schemas["book"].Properties.ContainsKey("title"), "Reader should NOT see 'title'");

            // Writer role should only see 'title', not 'id'.
            OpenApiDocument writerDoc = await GenerateDocumentWithPermissions(permissions, role: "writer");
            Assert.IsTrue(writerDoc.Components.Schemas.ContainsKey("book"), "Writer schema should exist");
            Assert.IsTrue(writerDoc.Components.Schemas["book"].Properties.ContainsKey("title"), "Writer should see 'title'");
            Assert.IsFalse(writerDoc.Components.Schemas["book"].Properties.ContainsKey("id"), "Writer should NOT see 'id'");
        }

        /// <summary>
        /// Builds a single-entity ("book" -> books table) runtime config and generates
        /// an OpenAPI document, optionally filtered to a specific role.
        /// </summary>
        /// <param name="permissions">Entity permissions to apply to the "book" entity.</param>
        /// <param name="role">Role to filter the document for; null generates the superset document.</param>
        private static async Task<OpenApiDocument> GenerateDocumentWithPermissions(EntityPermission[] permissions, string role = null)
        {
            Entity entity = new(
                Source: new("books", EntitySourceType.Table, null, null),
                Fields: null,
                GraphQL: new(null, null, false),
                Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS),
                Permissions: permissions,
                Mappings: null,
                Relationships: null);

            RuntimeEntities entities = new(new Dictionary<string, Entity> { { "book", entity } });
            return await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(entities, CONFIG_FILE, DB_ENV, requestBodyStrict: null, role: role);
        }
    }
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Azure.DataApiBuilder.Config.ObjectModel;
using Microsoft.OpenApi.Models;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Azure.DataApiBuilder.Service.Tests.OpenApiIntegration
{
    /// <summary>
    /// Tests for role-specific OpenAPI endpoint functionality including
    /// caching and case-insensitivity.
    /// </summary>
    [TestCategory(TestCategory.MSSQL)]
    [TestClass]
    public class RoleSpecificEndpointTests
    {
        // Config file and database environment used to bootstrap document generation.
        private const string CONFIG_FILE = "role-specific-endpoint-config.MsSql.json";
        private const string DB_ENV = TestCategory.MSSQL;

        /// <summary>
        /// Validates that role-specific OpenAPI documents are properly generated
        /// and contain expected content.
        /// </summary>
        [TestMethod]
        public async Task RoleSpecificDocument_GeneratesCorrectly()
        {
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(
                    Role: "reader",
                    Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) })
            };

            // Generate document for 'reader' role.
            OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions, role: "reader");

            Assert.IsNotNull(doc, "Document should not be null");
            Assert.IsTrue(doc.Paths.Count > 0, "Document should contain paths");
            Assert.IsTrue(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)), "Reader role should have GET");
            Assert.IsFalse(doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)), "Reader role should NOT have POST");
        }

        /// <summary>
        /// Validates that role names are case-insensitive when matching.
        /// </summary>
        [DataTestMethod]
        [DataRow("reader")]
        [DataRow("READER")]
        [DataRow("Reader")]
        [DataRow("rEaDeR")]
        public async Task RoleSpecificDocument_IsCaseInsensitive(string roleVariant)
        {
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(
                    Role: "reader",
                    Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) })
            };

            OpenApiDocument doc = await GenerateDocumentWithPermissions(permissions, role: roleVariant);

            Assert.IsNotNull(doc, $"Document for role variant '{roleVariant}' should not be null");
            Assert.IsTrue(doc.Paths.Count > 0, "Document should contain paths");
            Assert.IsTrue(
                doc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)),
                $"GET should be available for role '{roleVariant}'");
        }

        /// <summary>
        /// Validates that superset document contains operations from all roles
        /// while role-specific documents only contain that role's operations.
        /// </summary>
        [TestMethod]
        public async Task SupersetDocument_ContainsAllRoleOperations()
        {
            EntityPermission[] permissions = new[]
            {
                new EntityPermission(
                    Role: "reader",
                    Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) }),
                new EntityPermission(
                    Role: "writer",
                    Actions: new[]
                    {
                        new EntityAction(EntityActionOperation.Create, null, new()),
                        new EntityAction(EntityActionOperation.Update, null, new())
                    })
            };

            // Superset (no role) should have all operations.
            OpenApiDocument supersetDoc = await GenerateDocumentWithPermissions(permissions);
            Assert.IsTrue(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)),
                "Superset should have GET");
            Assert.IsTrue(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)),
                "Superset should have POST");
            Assert.IsTrue(
                supersetDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)),
                "Superset should have PUT");

            // Reader role should only have GET.
            OpenApiDocument readerDoc = await GenerateDocumentWithPermissions(permissions, role: "reader");
            Assert.IsTrue(
                readerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)),
                "Reader should have GET");
            Assert.IsFalse(
                readerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)),
                "Reader should NOT have POST");

            // Writer role should only have POST, PUT, PATCH.
            OpenApiDocument writerDoc = await GenerateDocumentWithPermissions(permissions, role: "writer");
            Assert.IsTrue(
                writerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)),
                "Writer should have POST");
            Assert.IsTrue(
                writerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)),
                "Writer should have PUT");
            Assert.IsFalse(
                writerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Get)),
                "Writer should NOT have GET");
            Assert.IsFalse(
                writerDoc.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Delete)),
                "Writer should NOT have DELETE");
        }

        /// <summary>
        /// Validates that request body schemas (_NoAutoPK, _NoPK) are only generated
        /// when mutation operations (POST, PUT, PATCH) are available.
        /// This optimization reduces document size for read-only entities.
        /// </summary>
        [TestMethod]
        public async Task RequestBodySchemas_OnlyGeneratedForMutationOperations()
        {
            // Create+Update permissions enable PUT/PATCH (mutation operations present).
            EntityPermission[] permissionsWithUpdate = new[]
            {
                new EntityPermission(
                    Role: "editor",
                    Actions: new[]
                    {
                        new EntityAction(EntityActionOperation.Create, null, new()),
                        new EntityAction(EntityActionOperation.Update, null, new())
                    })
            };

            OpenApiDocument docWithMutations = await GenerateDocumentWithPermissions(permissionsWithUpdate);
            Assert.IsTrue(
                docWithMutations.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)),
                "Should have PUT");
            Assert.IsTrue(
                docWithMutations.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)),
                "Should have PATCH");
            // Request body schemas should be present for mutation operations.
            Assert.IsTrue(
                docWithMutations.Components.Schemas.ContainsKey("book_NoAutoPK"),
                "Should have request body schema for mutations");
            Assert.IsTrue(
                docWithMutations.Components.Schemas.ContainsKey("book_NoPK"),
                "Should have alternate request body schema for mutations");

            // Read-only permissions - no mutation operations.
            EntityPermission[] permissionsReadOnly = new[]
            {
                new EntityPermission(
                    Role: "reader",
                    Actions: new[] { new EntityAction(EntityActionOperation.Read, null, new()) })
            };

            OpenApiDocument docReadOnly = await GenerateDocumentWithPermissions(permissionsReadOnly);
            Assert.IsFalse(
                docReadOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Put)),
                "Should NOT have PUT");
            Assert.IsFalse(
                docReadOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Patch)),
                "Should NOT have PATCH");
            Assert.IsFalse(
                docReadOnly.Paths.Any(p => p.Value.Operations.ContainsKey(OperationType.Post)),
                "Should NOT have POST");
            // Request body schemas should NOT be generated for read-only entities (optimization).
            Assert.IsFalse(
                docReadOnly.Components.Schemas.ContainsKey("book_NoAutoPK"),
                "Should NOT have request body schema for read-only entity");
            Assert.IsFalse(
                docReadOnly.Components.Schemas.ContainsKey("book_NoPK"),
                "Should NOT have alternate request body schema for read-only entity");
        }

        /// <summary>
        /// Builds a single-entity ("book" -> books table) runtime config and generates
        /// an OpenAPI document, optionally filtered to a specific role.
        /// </summary>
        /// <param name="permissions">Entity permissions to apply to the "book" entity.</param>
        /// <param name="role">Role to filter the document for; null generates the superset document.</param>
        private static async Task<OpenApiDocument> GenerateDocumentWithPermissions(
            EntityPermission[] permissions,
            string role = null)
        {
            Entity entity = new(
                Source: new("books", EntitySourceType.Table, null, null),
                Fields: null,
                GraphQL: new(null, null, false),
                Rest: new(EntityRestOptions.DEFAULT_SUPPORTED_VERBS),
                Permissions: permissions,
                Mappings: null,
                Relationships: null);

            RuntimeEntities entities = new(new Dictionary<string, Entity> { { "book", entity } });
            return await OpenApiTestBootstrap.GenerateOpenApiDocumentAsync(
                entities,
                CONFIG_FILE,
                DB_ENV,
                requestBodyStrict: null,
                role: role);
        }
    }
}
+693,23 @@ await SetupAndRunRestApiTest( ); } + /// + /// Tests the REST Api for Find operation with a filter containing special characters + /// that need to be URL-encoded. Uses existing book with '%' character (SOME%CONN). + /// This validates that the fix for the double-decoding issue is working correctly. + /// + [TestMethod] + public async Task FindTestWithFilterContainingSpecialCharacters() + { + // Testing with SOME%CONN - the %25 is URL-encoded % + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: "?$filter=title%20eq%20%27SOME%25CONN%27", + entityNameOrPath: _integrationEntityName, + sqlQuery: GetQuery(nameof(FindTestWithFilterContainingSpecialCharacters)) + ); + } + /// /// Tests the REST Api for Find operation where we compare one field /// to the bool returned from another comparison. diff --git a/src/Service.Tests/SqlTests/RestApiTests/Find/MsSqlFindApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Find/MsSqlFindApiTests.cs index 6f43fb2073..6e611834c8 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Find/MsSqlFindApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Find/MsSqlFindApiTests.cs @@ -228,6 +228,12 @@ public class MsSqlFindApiTests : FindApiTestBase $"WHERE (NOT (id < 3) OR id < 4) OR NOT (title = 'Awesome book') " + $"FOR JSON PATH, INCLUDE_NULL_VALUES" }, + { + "FindTestWithFilterContainingSpecialCharacters", + $"SELECT * FROM { _integrationTableName } " + + $"WHERE title = 'SOME%CONN' " + + $"FOR JSON PATH, INCLUDE_NULL_VALUES" + }, { "FindTestWithPrimaryKeyContainingForeignKey", $"SELECT [id], [content] FROM reviews " + diff --git a/src/Service.Tests/SqlTests/RestApiTests/Find/MySqlFindApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Find/MySqlFindApiTests.cs index f9a3fdb764..0e7e89364c 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Find/MySqlFindApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Find/MySqlFindApiTests.cs @@ -397,6 +397,18 @@ ORDER BY id asc ) AS subq " }, + { 
+ "FindTestWithFilterContainingSpecialCharacters", + @" + SELECT JSON_ARRAYAGG(JSON_OBJECT('id', id, 'title', title, 'publisher_id', publisher_id)) AS data + FROM ( + SELECT * + FROM " + _integrationTableName + @" + WHERE title = 'SOME%CONN' + ORDER BY id asc + ) AS subq + " + }, { "FindTestWithFilterQueryStringBoolResultFilter", @" diff --git a/src/Service.Tests/SqlTests/RestApiTests/Find/PostgreSqlFindApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Find/PostgreSqlFindApiTests.cs index 9abcfe88c2..76edca949e 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Find/PostgreSqlFindApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Find/PostgreSqlFindApiTests.cs @@ -411,6 +411,17 @@ SELECT json_agg(to_jsonb(subq)) AS data ORDER BY id asc ) AS subq" }, + { + "FindTestWithFilterContainingSpecialCharacters", + @" + SELECT json_agg(to_jsonb(subq)) AS data + FROM ( + SELECT * + FROM " + _integrationTableName + @" + WHERE title = 'SOME%CONN' + ORDER BY id asc + ) AS subq" + }, { "FindTestWithPrimaryKeyContainingForeignKey", @" diff --git a/src/Service.Tests/SqlTests/RestApiTests/Patch/MsSqlPatchApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Patch/MsSqlPatchApiTests.cs index eeb97badc9..e711b648dc 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Patch/MsSqlPatchApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Patch/MsSqlPatchApiTests.cs @@ -18,6 +18,13 @@ public class MsSqlPatchApiTests : PatchApiTestBase { private static Dictionary _queryMap = new() { + { + "PatchOne_Insert_KeylessWithAutoGenPK_Test", + $"SELECT [id], [title], [publisher_id] FROM { _integrationTableName } " + + $"WHERE [id] = { STARTING_ID_FOR_TEST_INSERTS } AND [title] = 'My New Book' " + + $"AND [publisher_id] = 1234 " + + $"FOR JSON PATH, INCLUDE_NULL_VALUES, WITHOUT_ARRAY_WRAPPER" + }, { "PatchOne_Insert_NonAutoGenPK_Test", $"SELECT [id], [title], [issue_number] FROM [foo].{ _integration_NonAutoGenPK_TableName } " + diff --git 
a/src/Service.Tests/SqlTests/RestApiTests/Patch/MySqlPatchApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Patch/MySqlPatchApiTests.cs index de72fe8b80..65cf224e75 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Patch/MySqlPatchApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Patch/MySqlPatchApiTests.cs @@ -13,6 +13,17 @@ public class MySqlPatchApiTests : PatchApiTestBase { protected static Dictionary _queryMap = new() { + { + "PatchOne_Insert_KeylessWithAutoGenPK_Test", + @"SELECT JSON_OBJECT('id', id, 'title', title, 'publisher_id', publisher_id) AS data + FROM ( + SELECT id, title, publisher_id + FROM " + _integrationTableName + @" + WHERE id = " + STARTING_ID_FOR_TEST_INSERTS + @" + AND title = 'My New Book' AND publisher_id = 1234 + ) AS subq + " + }, { "PatchOne_Insert_NonAutoGenPK_Test", @"SELECT JSON_OBJECT('id', id, 'title', title, 'issue_number', issue_number ) AS data diff --git a/src/Service.Tests/SqlTests/RestApiTests/Patch/PatchApiTestBase.cs b/src/Service.Tests/SqlTests/RestApiTests/Patch/PatchApiTestBase.cs index 8aad5b0654..ba5971c02c 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Patch/PatchApiTestBase.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Patch/PatchApiTestBase.cs @@ -5,7 +5,6 @@ using System.Net; using System.Threading.Tasks; using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; @@ -346,6 +345,34 @@ public virtual Task PatchOneUpdateTestOnTableWithSecurityPolicy() return Task.CompletedTask; } + /// + /// Tests the PatchOne functionality with a REST PATCH request + /// without a primary key route on an entity with an auto-generated primary key. + /// With keyless PATCH support, ValidateUpsertRequestContext allows this because + /// all PK columns are auto-generated. The mutation engine then performs an insert + /// and succeeds with 201 Created. 
+ /// + [TestMethod] + public virtual async Task PatchOne_Insert_KeylessWithAutoGenPK_Test() + { + string requestBody = @" + { + ""title"": ""My New Book"", + ""publisher_id"": 1234 + }"; + + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: null, + entityNameOrPath: _integrationEntityName, + sqlQuery: GetQuery(nameof(PatchOne_Insert_KeylessWithAutoGenPK_Test)), + operationType: EntityActionOperation.UpsertIncremental, + requestBody: requestBody, + expectedStatusCode: HttpStatusCode.Created, + expectedLocationHeader: string.Empty + ); + } + /// /// Tests successful execution of PATCH update requests on views /// when requests try to modify fields belonging to one base table @@ -931,8 +958,9 @@ await SetupAndRunRestApiTest( /// /// Tests the Patch functionality with a REST PATCH request - /// without a primary key route. We expect a failure and so - /// no sql query is provided. + /// without a primary key route. For non-auto-generated PK entities, + /// ValidateUpsertRequestContext detects the missing required + /// non-auto-generated PK field in the body and returns a BadRequest. /// [TestMethod] public virtual async Task PatchWithNoPrimaryKeyRouteTest() @@ -951,11 +979,45 @@ await SetupAndRunRestApiTest( operationType: EntityActionOperation.UpsertIncremental, requestBody: requestBody, exceptionExpected: true, - expectedErrorMessage: RequestValidator.PRIMARY_KEY_NOT_PROVIDED_ERR_MESSAGE, - expectedStatusCode: HttpStatusCode.BadRequest + expectedErrorMessage: "Invalid request body. Missing field in body: id.", + expectedStatusCode: HttpStatusCode.BadRequest, + expectedSubStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest.ToString() ); } + /// + /// Tests the Patch functionality with a REST PATCH request + /// without a primary key route on an entity with a composite non-auto-generated primary key, + /// where the body only contains a partial key. 
ValidateUpsertRequestContext detects the + /// missing non-auto-generated PK field and returns a BadRequest. + /// + [TestMethod] + public virtual async Task PatchWithNoPrimaryKeyRouteAndPartialCompositeKeyInBodyTest() + { + // Body only contains categoryid but not pieceid — both are required + // since neither is auto-generated. + string requestBody = @" + { + ""categoryid"": 100, + ""categoryName"": ""SciFi"", + ""piecesAvailable"": 5, + ""piecesRequired"": 3 + }"; + + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: null, + entityNameOrPath: _Composite_NonAutoGenPK_EntityPath, + sqlQuery: string.Empty, + operationType: EntityActionOperation.UpsertIncremental, + requestBody: requestBody, + exceptionExpected: true, + expectedErrorMessage: "Invalid request body. Missing field in body: pieceid.", + expectedStatusCode: HttpStatusCode.BadRequest, + expectedSubStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest.ToString() + ); + } + /// /// Test to validate failure of PATCH operation failing to satisfy the database policy for the update operation. /// (because a record exists for given PK). @@ -988,7 +1050,7 @@ await SetupAndRunRestApiTest( } /// - /// Test to validate failure of PATCH operation failing to satisfy the database policy for the update operation. + /// Test to validate failure of PATCH operation failing to satisfy the database policy for the insert operation. /// (because no record exists for given PK). 
/// [TestMethod] diff --git a/src/Service.Tests/SqlTests/RestApiTests/Patch/PostgreSqlPatchApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Patch/PostgreSqlPatchApiTests.cs index 441c96425c..0a808bae58 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Patch/PostgreSqlPatchApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Patch/PostgreSqlPatchApiTests.cs @@ -13,6 +13,18 @@ public class PostgreSqlPatchApiTests : PatchApiTestBase { protected static Dictionary _queryMap = new() { + { + "PatchOne_Insert_KeylessWithAutoGenPK_Test", + @" + SELECT to_jsonb(subq) AS data + FROM ( + SELECT id, title, publisher_id + FROM " + _integrationTableName + @" + WHERE id = " + STARTING_ID_FOR_TEST_INSERTS + @" + AND title = 'My New Book' AND publisher_id = 1234 + ) AS subq + " + }, { "PatchOne_Insert_Mapping_Test", @" diff --git a/src/Service.Tests/SqlTests/RestApiTests/Put/MsSqlPutApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Put/MsSqlPutApiTests.cs index 5b2745e203..7ae15bb510 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Put/MsSqlPutApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Put/MsSqlPutApiTests.cs @@ -18,6 +18,13 @@ public class MsSqlPutApiTests : PutApiTestBase { private static Dictionary _queryMap = new() { + { + "PutOne_Insert_KeylessWithAutoGenPK_Test", + $"SELECT [id], [title], [publisher_id] FROM { _integrationTableName } " + + $"WHERE [id] = { STARTING_ID_FOR_TEST_INSERTS } AND [title] = 'My New Book' " + + $"AND [publisher_id] = 1234 " + + $"FOR JSON PATH, INCLUDE_NULL_VALUES, WITHOUT_ARRAY_WRAPPER" + }, { "PutOne_Update_Test", $"SELECT [id], [title], [publisher_id] FROM { _integrationTableName } " + diff --git a/src/Service.Tests/SqlTests/RestApiTests/Put/MySqlPutApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Put/MySqlPutApiTests.cs index 7708186f65..89024053b4 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Put/MySqlPutApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Put/MySqlPutApiTests.cs @@ -13,6 
+13,18 @@ public class MySqlPutApiTests : PutApiTestBase { protected static Dictionary _queryMap = new() { + { + "PutOne_Insert_KeylessWithAutoGenPK_Test", + @" + SELECT JSON_OBJECT('id', id, 'title', title, 'publisher_id', publisher_id) AS data + FROM ( + SELECT id, title, publisher_id + FROM " + _integrationTableName + @" + WHERE id = " + STARTING_ID_FOR_TEST_INSERTS + @" + AND title = 'My New Book' AND publisher_id = 1234 + ) AS subq + " + }, { "PutOne_Update_Test", @" diff --git a/src/Service.Tests/SqlTests/RestApiTests/Put/PostgreSqlPutApiTests.cs b/src/Service.Tests/SqlTests/RestApiTests/Put/PostgreSqlPutApiTests.cs index c9e527876b..e9f8bcaac1 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Put/PostgreSqlPutApiTests.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Put/PostgreSqlPutApiTests.cs @@ -14,6 +14,18 @@ public class PostgreSqlPutApiTests : PutApiTestBase { protected static Dictionary _queryMap = new() { + { + "PutOne_Insert_KeylessWithAutoGenPK_Test", + @" + SELECT to_jsonb(subq) AS data + FROM ( + SELECT id, title, publisher_id + FROM " + _integrationTableName + @" + WHERE id = " + STARTING_ID_FOR_TEST_INSERTS + @" + AND title = 'My New Book' AND publisher_id = 1234 + ) AS subq + " + }, { "PutOne_Update_Test", @" diff --git a/src/Service.Tests/SqlTests/RestApiTests/Put/PutApiTestBase.cs b/src/Service.Tests/SqlTests/RestApiTests/Put/PutApiTestBase.cs index eb6cfb3767..503a8388a4 100644 --- a/src/Service.Tests/SqlTests/RestApiTests/Put/PutApiTestBase.cs +++ b/src/Service.Tests/SqlTests/RestApiTests/Put/PutApiTestBase.cs @@ -227,6 +227,34 @@ public virtual Task PutOneUpdateTestOnTableWithSecurityPolicy() return Task.CompletedTask; } + /// + /// Tests the PutOne functionality with a REST PUT request + /// without a primary key route on an entity with an auto-generated primary key. + /// With keyless PUT support, ValidateUpsertRequestContext allows this because + /// all PK columns are auto-generated. 
The mutation engine then performs an insert + /// and succeeds with 201 Created. + /// + [TestMethod] + public virtual async Task PutOne_Insert_KeylessWithAutoGenPK_Test() + { + string requestBody = @" + { + ""title"": ""My New Book"", + ""publisher_id"": 1234 + }"; + + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: null, + entityNameOrPath: _integrationEntityName, + sqlQuery: GetQuery(nameof(PutOne_Insert_KeylessWithAutoGenPK_Test)), + operationType: EntityActionOperation.Upsert, + requestBody: requestBody, + expectedStatusCode: HttpStatusCode.Created, + expectedLocationHeader: string.Empty + ); + } + /// /// Tests the PutOne functionality with a REST PUT request using /// headers that include as a key "If-Match" with an item that does exist, @@ -998,8 +1026,9 @@ await SetupAndRunRestApiTest( /// /// Tests the Put functionality with a REST PUT request - /// without a primary key route. We expect a failure and so - /// no sql query is provided. + /// without a primary key route. For non-auto-generated PK entities, + /// ValidateUpsertRequestContext detects the missing required + /// non-auto-generated PK field in the body and returns a BadRequest. /// [TestMethod] public virtual async Task PutWithNoPrimaryKeyRouteTest() @@ -1018,11 +1047,75 @@ await SetupAndRunRestApiTest( operationType: EntityActionOperation.Upsert, requestBody: requestBody, exceptionExpected: true, + expectedErrorMessage: "Invalid request body. Missing field in body: id.", + expectedStatusCode: HttpStatusCode.BadRequest, + expectedSubStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest.ToString() + ); + } + + /// + /// Tests that a PUT request with If-Match header (strict update semantics) + /// still requires a primary key route. When If-Match is present, the operation + /// becomes Update (not Upsert), so it cannot be converted to Insert. 
+ /// + [TestMethod] + public virtual async Task PutWithNoPrimaryKeyRouteAndIfMatchHeaderTest() + { + Dictionary headerDictionary = new(); + headerDictionary.Add("If-Match", "*"); + string requestBody = @" + { + ""title"": ""Batman Returns"", + ""publisher_id"": 1234 + }"; + + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: null, + entityNameOrPath: _integrationEntityName, + sqlQuery: string.Empty, + operationType: EntityActionOperation.Upsert, + headers: new HeaderDictionary(headerDictionary), + requestBody: requestBody, + exceptionExpected: true, expectedErrorMessage: RequestValidator.PRIMARY_KEY_NOT_PROVIDED_ERR_MESSAGE, expectedStatusCode: HttpStatusCode.BadRequest ); } + /// + /// Tests the Put functionality with a REST PUT request + /// without a primary key route on an entity with a composite non-auto-generated primary key, + /// where the body only contains a partial key. ValidateUpsertRequestContext detects the + /// missing non-auto-generated PK field and returns a BadRequest. + /// + [TestMethod] + public virtual async Task PutWithNoPrimaryKeyRouteAndPartialCompositeKeyInBodyTest() + { + // Body only contains categoryid but not pieceid — both are required + // since neither is auto-generated. + string requestBody = @" + { + ""categoryid"": 100, + ""categoryName"": ""SciFi"", + ""piecesAvailable"": 5, + ""piecesRequired"": 3 + }"; + + await SetupAndRunRestApiTest( + primaryKeyRoute: string.Empty, + queryString: null, + entityNameOrPath: _Composite_NonAutoGenPK_EntityPath, + sqlQuery: string.Empty, + operationType: EntityActionOperation.Upsert, + requestBody: requestBody, + exceptionExpected: true, + expectedErrorMessage: "Invalid request body. Missing field in body: pieceid.", + expectedStatusCode: HttpStatusCode.BadRequest, + expectedSubStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest.ToString() + ); + } + /// /// Tests that a cast failure of primary key value type results in HTTP 400 Bad Request. /// e.g. 
Attempt to cast a string '{}' to the 'publisher_id' column type of int will fail. diff --git a/src/Service.Tests/SqlTests/SqlTestBase.cs b/src/Service.Tests/SqlTests/SqlTestBase.cs index 16e804f117..4e0fa5249c 100644 --- a/src/Service.Tests/SqlTests/SqlTestBase.cs +++ b/src/Service.Tests/SqlTests/SqlTestBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Abstractions; using System.Net; using System.Net.Http; using System.Net.Http.Json; @@ -279,6 +280,10 @@ protected static void SetUpSQLMetadataProvider(RuntimeConfigProvider runtimeConf _queryManagerFactory = new Mock(); Mock httpContextAccessor = new(); string dataSourceName = runtimeConfigProvider.GetConfig().DefaultDataSourceName; + IFileSystem fileSystem = new FileSystem(); + Mock> loggerValidator = new(); + RuntimeConfigValidator runtimeConfigValidator = new(runtimeConfigProvider, fileSystem, loggerValidator.Object); + switch (DatabaseEngine) { case TestCategory.POSTGRESQL: @@ -297,6 +302,7 @@ protected static void SetUpSQLMetadataProvider(RuntimeConfigProvider runtimeConf _sqlMetadataProvider = new PostgreSqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, _queryManagerFactory.Object, _sqlMetadataLogger, dataSourceName); @@ -317,6 +323,7 @@ protected static void SetUpSQLMetadataProvider(RuntimeConfigProvider runtimeConf _sqlMetadataProvider = new MsSqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, _queryManagerFactory.Object, _sqlMetadataLogger, dataSourceName); @@ -337,6 +344,7 @@ protected static void SetUpSQLMetadataProvider(RuntimeConfigProvider runtimeConf _sqlMetadataProvider = new MySqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, _queryManagerFactory.Object, _sqlMetadataLogger, dataSourceName); @@ -357,6 +365,7 @@ protected static void SetUpSQLMetadataProvider(RuntimeConfigProvider runtimeConf _sqlMetadataProvider = new MsSqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, 
_queryManagerFactory.Object, _sqlMetadataLogger, dataSourceName); diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index 119e6637c6..05561e4cf9 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -437,6 +437,115 @@ string relationshipEntity configValidator.ValidateRelationships(runtimeConfig, _metadataProviderFactory.Object); } + /// + /// Test method to verify that many-to-many relationships work correctly when the linking object + /// is in a custom schema (not dbo). This test validates that schema names are correctly compared + /// using case-insensitive comparison, which is important for SQL Server where schema names are + /// case-insensitive. + /// + [DataRow("mySchema.TEST_SOURCE_LINK", "mySchema", "TEST_SOURCE_LINK", DisplayName = "Linking object with custom schema")] + [DataRow("MYSCHEMA.TEST_SOURCE_LINK", "MYSCHEMA", "TEST_SOURCE_LINK", DisplayName = "Linking object with uppercase custom schema")] + [DataRow("myschema.test_source_link", "myschema", "test_source_link", DisplayName = "Linking object with lowercase schema and table")] + [DataTestMethod] + public void TestRelationshipWithLinkingObjectInCustomSchema( + string linkingObject, + string expectedSchema, + string expectedTable + ) + { + // Creating an EntityMap with two sample entities + Dictionary entityMap = GetSampleEntityMap( + sourceEntity: "SampleEntity1", + targetEntity: "SampleEntity2", + sourceFields: new string[] { "sourceField" }, + targetFields: new string[] { "targetField" }, + linkingObject: linkingObject, + linkingSourceFields: new string[] { "linkingSourceField" }, + linkingTargetFields: new string[] { "linkingTargetField" } + ); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: 
new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null) + ), + Entities: new(entityMap) + ); + + // Mocking EntityToDatabaseObject - entities are in the custom schema as well + MockFileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader) { IsLateConfigured = true }; + RuntimeConfigValidator configValidator = new(provider, fileSystem, new Mock>().Object); + Mock _sqlMetadataProvider = new(); + + Dictionary mockDictionaryForEntityDatabaseObject = new() + { + { + "SampleEntity1", + new DatabaseTable(expectedSchema, "TEST_SOURCE1") + }, + { + "SampleEntity2", + new DatabaseTable(expectedSchema, "TEST_SOURCE2") + } + }; + + _sqlMetadataProvider.Setup(x => x.EntityToDatabaseObject).Returns(mockDictionaryForEntityDatabaseObject); + + // To mock the schema name and dbObjectName for linkingObject + _sqlMetadataProvider.Setup(x => + x.ParseSchemaAndDbTableName(linkingObject)).Returns((expectedSchema, expectedTable)); + + string discard; + _sqlMetadataProvider.Setup(x => x.TryGetExposedColumnName(It.IsAny(), It.IsAny(), out discard)).Returns(true); + + Mock _metadataProviderFactory = new(); + _metadataProviderFactory.Setup(x => x.GetMetadataProvider(It.IsAny())).Returns(_sqlMetadataProvider.Object); + + // Mock ForeignKeyPair to be defined in DB with the custom schema + // The schema comparison should be case-insensitive + // Use concrete DatabaseTable instances with differing casing so that + // Moq relies on DatabaseTable.Equals for argument matching. + // Linking table uses lowercase to ensure case-insensitive comparison is working. 
+ DatabaseTable expectedLinkingTable = new(expectedSchema.ToLowerInvariant(), expectedTable.ToLowerInvariant()); + DatabaseTable expectedSource1Table = new(expectedSchema.ToUpperInvariant(), "TEST_SOURCE1"); + DatabaseTable expectedSource2Table = new(expectedSchema.ToUpperInvariant(), "TEST_SOURCE2"); + + _sqlMetadataProvider.Setup(x => + x.VerifyForeignKeyExistsInDB( + expectedLinkingTable, + expectedSource1Table + )).Returns(true); + + _sqlMetadataProvider.Setup(x => + x.VerifyForeignKeyExistsInDB( + expectedLinkingTable, + expectedSource2Table + )).Returns(true); + + // Validation should pass with custom schema + configValidator.ValidateRelationships(runtimeConfig, _metadataProviderFactory.Object); + + // Verify that VerifyForeignKeyExistsInDB is never called with 'dbo' schema, + // guarding against the original dbo-fallback regression. + _sqlMetadataProvider.Verify(x => + x.VerifyForeignKeyExistsInDB( + It.Is(t => string.Equals(t.SchemaName, "dbo", StringComparison.OrdinalIgnoreCase)), + It.IsAny() + ), Times.Never); + + _sqlMetadataProvider.Verify(x => + x.VerifyForeignKeyExistsInDB( + It.IsAny(), + It.Is(t => string.Equals(t.SchemaName, "dbo", StringComparison.OrdinalIgnoreCase)) + ), Times.Never); + } + /// /// Test method to check that an exception is thrown when the relationship is one-many /// or many-one (determined by the linking object being null), while both SourceFields @@ -2066,21 +2175,41 @@ public void ValidateRestMethodsForEntityInConfig( [DataTestMethod] [DataRow(true, "EntityA", "", true, "The rest path for entity: EntityA cannot be empty.", DisplayName = "Empty rest path configured for an entity fails config validation.")] - [DataRow(true, "EntityA", "entity?RestPath", true, "The rest path: entity?RestPath for entity: EntityA contains one or more reserved characters.", - DisplayName = "Rest path for an entity containing reserved character ? 
fails config validation.")] - [DataRow(true, "EntityA", "entity#RestPath", true, "The rest path: entity#RestPath for entity: EntityA contains one or more reserved characters.", - DisplayName = "Rest path for an entity containing reserved character ? fails config validation.")] - [DataRow(true, "EntityA", "entity[]RestPath", true, "The rest path: entity[]RestPath for entity: EntityA contains one or more reserved characters.", - DisplayName = "Rest path for an entity containing reserved character ? fails config validation.")] - [DataRow(true, "EntityA", "entity+Rest*Path", true, "The rest path: entity+Rest*Path for entity: EntityA contains one or more reserved characters.", + [DataRow(true, "EntityA", "entity?RestPath", true, "The rest path: entity?RestPath for entity: EntityA contains '?' which is reserved for query strings in URLs.", DisplayName = "Rest path for an entity containing reserved character ? fails config validation.")] - [DataRow(true, "Entity?A", null, true, "The rest path: Entity?A for entity: Entity?A contains one or more reserved characters.", + [DataRow(true, "EntityA", "entity#RestPath", true, "The rest path: entity#RestPath for entity: EntityA contains '#' which is reserved for URL fragments.", + DisplayName = "Rest path for an entity containing reserved character # fails config validation.")] + [DataRow(true, "EntityA", "entity[]RestPath", true, "The rest path: entity[]RestPath for entity: EntityA contains reserved characters that are not allowed in URL paths.", + DisplayName = "Rest path for an entity containing reserved character [] fails config validation.")] + [DataRow(true, "EntityA", "entity+Rest*Path", true, "The rest path: entity+Rest*Path for entity: EntityA contains reserved characters that are not allowed in URL paths.", + DisplayName = "Rest path for an entity containing reserved character +* fails config validation.")] + [DataRow(true, "Entity?A", null, true, "The rest path: Entity?A for entity: Entity?A contains '?' 
which is reserved for query strings in URLs.", DisplayName = "Entity name for an entity containing reserved character ? fails config validation.")] - [DataRow(true, "Entity&*[]A", null, true, "The rest path: Entity&*[]A for entity: Entity&*[]A contains one or more reserved characters.", - DisplayName = "Entity name containing reserved character ? fails config validation.")] + [DataRow(true, "Entity&*[]A", null, true, "The rest path: Entity&*[]A for entity: Entity&*[]A contains reserved characters that are not allowed in URL paths.", + DisplayName = "Entity name containing reserved character &*[] fails config validation.")] [DataRow(false, "EntityA", "entityRestPath", true, DisplayName = "Rest path correctly configured as a non-empty string without any reserved characters.")] [DataRow(false, "EntityA", "entityRest/?Path", false, DisplayName = "Rest path for an entity containing reserved character but with rest disabled passes config validation.")] + [DataRow(false, "EntityA", "shopping-cart/item", true, + DisplayName = "Rest path with sub-directory passes config validation.")] + [DataRow(false, "EntityA", "api/v1/books", true, + DisplayName = "Rest path with multiple sub-directories passes config validation.")] + [DataRow(true, "EntityA", "entity\\path", true, "The rest path: entity\\path for entity: EntityA contains a backslash (\\). Use forward slash (/) for path separators.", + DisplayName = "Rest path with backslash fails config validation with helpful message.")] + [DataRow(false, "EntityA", "/entity/path", true, + DisplayName = "Rest path with leading slash is trimmed and passes config validation.")] + [DataRow(true, "EntityA", "entity//path", true, "The rest path: entity//path for entity: EntityA contains empty path segments. 
Ensure there are no leading, consecutive, or trailing slashes.", + DisplayName = "Rest path with consecutive slashes fails config validation.")] + [DataRow(true, "EntityA", "entity/path/", true, "The rest path: entity/path/ for entity: EntityA contains empty path segments. Ensure there are no leading, consecutive, or trailing slashes.", + DisplayName = "Rest path with trailing slash fails config validation.")] + [DataRow(true, "EntityA", "entity /path", true, "The rest path: entity /path for entity: EntityA contains whitespace which is not allowed in URL paths.", + DisplayName = "Rest path with whitespace fails config validation with helpful message.")] + [DataRow(true, "EntityA", "entity%3Frest", true, "The rest path: entity%3Frest for entity: EntityA contains percent-encoding (%) which is not allowed. Use literal characters only.", + DisplayName = "Rest path with percent-encoded characters fails config validation.")] + [DataRow(true, "EntityA", "entity/../path", true, "The rest path: entity/../path for entity: EntityA contains path traversal patterns ('.' or '..') which are not allowed.", + DisplayName = "Rest path with dot-dot segments fails config validation.")] + [DataRow(true, "EntityA", "entity/./path", true, "The rest path: entity/./path for entity: EntityA contains path traversal patterns ('.' 
or '..') which are not allowed.", + DisplayName = "Rest path with dot segments fails config validation.")] public void ValidateRestPathForEntityInConfig( bool exceptionExpected, string entityName, @@ -2196,6 +2325,407 @@ public void ValidateUniqueRestPathsForEntitiesInConfig( } } + [TestMethod] + public void UserDelegatedAuthRequiresMssqlDatabaseType() + { + string runtimeConfigString = @"{ + ""$schema"": ""test_schema"", + ""data-source"": { + ""database-type"": ""postgresql"", + ""connection-string"": """ + SAMPLE_TEST_CONN_STRING + @""", + ""user-delegated-auth"": { + ""enabled"": true, + ""database-audience"": ""https://database.example"" + } + }, + ""runtime"": { + ""host"": { + ""authentication"": { + ""provider"": ""AzureAD"", + ""jwt"": { + ""audience"": ""api-audience"", + ""issuer"": ""https://login.microsoftonline.com/common/v2.0"" + } + } + } + }, + ""entities"": { } + }"; + + RuntimeConfigLoader.TryParseConfig(runtimeConfigString, out RuntimeConfig runtimeConfig); + MockFileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + Mock> loggerMock = new(); + RuntimeConfigValidator configValidator = new(provider, fileSystem, loggerMock.Object); + + DataApiBuilderException dabException = Assert.ThrowsException( + () => configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, loggerMock.Object)); + + Assert.AreEqual(expected: HttpStatusCode.ServiceUnavailable, actual: dabException.StatusCode); + Assert.AreEqual(expected: DataApiBuilderException.SubStatusCodes.ConfigValidationError, actual: dabException.SubStatusCode); + } + + [TestMethod] + public void UserDelegatedAuthRequiresDatabaseAudience() + { + string runtimeConfigString = @"{ + ""$schema"": ""test_schema"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": """ + SAMPLE_TEST_CONN_STRING + @""", + ""user-delegated-auth"": { + ""enabled"": true + } + }, + ""runtime"": { + ""host"": { 
+ ""authentication"": { + ""provider"": ""AzureAD"", + ""jwt"": { + ""audience"": ""api-audience"", + ""issuer"": ""https://login.microsoftonline.com/common/v2.0"" + } + } + } + }, + ""entities"": { } + }"; + + RuntimeConfigLoader.TryParseConfig(runtimeConfigString, out RuntimeConfig runtimeConfig); + MockFileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + Mock> loggerMock = new(); + RuntimeConfigValidator configValidator = new(provider, fileSystem, loggerMock.Object); + + DataApiBuilderException dabException = Assert.ThrowsException( + () => configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, loggerMock.Object)); + + Assert.AreEqual(expected: HttpStatusCode.ServiceUnavailable, actual: dabException.StatusCode); + Assert.AreEqual(expected: DataApiBuilderException.SubStatusCodes.ConfigValidationError, actual: dabException.SubStatusCode); + } + + [TestMethod] + public void UserDelegatedAuthRequiresCachingDisabled() + { + // Arrange - Set environment variables for Azure AD credentials to ensure + // validation reaches the caching check (not failing on missing credentials) + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, "test-client-id"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, "test-client-secret"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, "test-tenant-id"); + + try + { + string runtimeConfigString = @"{ + ""$schema"": ""test_schema"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": """ + SAMPLE_TEST_CONN_STRING + @""", + ""user-delegated-auth"": { + ""enabled"": true, + ""database-audience"": ""https://database.example"" + } + }, + ""runtime"": { + ""cache"": { + ""enabled"": true + }, + ""host"": { + ""authentication"": { + ""provider"": ""AzureAD"", + ""jwt"": { + ""audience"": 
""api-audience"", + ""issuer"": ""https://login.microsoftonline.com/common/v2.0"" + } + } + } + }, + ""entities"": { } + }"; + + RuntimeConfigLoader.TryParseConfig(runtimeConfigString, out RuntimeConfig runtimeConfig); + MockFileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + Mock> loggerMock = new(); + RuntimeConfigValidator configValidator = new(provider, fileSystem, loggerMock.Object); + + DataApiBuilderException dabException = Assert.ThrowsException( + () => configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, loggerMock.Object)); + + Assert.AreEqual(expected: RuntimeConfigValidator.USER_DELEGATED_AUTH_CACHING_ERR_MSG, actual: dabException.Message); + Assert.AreEqual(expected: HttpStatusCode.ServiceUnavailable, actual: dabException.StatusCode); + Assert.AreEqual(expected: DataApiBuilderException.SubStatusCodes.ConfigValidationError, actual: dabException.SubStatusCode); + } + finally + { + // Clean up environment variables + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, null); + } + } + + /// + /// Test to validate that user-delegated-auth with missing, empty, or whitespace database-audience throws an error. 
+ /// + [DataTestMethod] + [DataRow(null, DisplayName = "Null audience should fail")] + [DataRow("", DisplayName = "Empty string audience should fail")] + [DataRow(" ", DisplayName = "Whitespace audience should fail")] + public void ValidateUserDelegatedAuth_InvalidDatabaseAudience_ThrowsError(string audience) + { + // Arrange + DataSource dataSource = new( + DatabaseType: DatabaseType.MSSQL, + ConnectionString: "Server=test;Database=test;", + Options: null) + { + UserDelegatedAuth = new UserDelegatedAuthOptions( + Enabled: true, + Provider: "EntraId", + DatabaseAudience: audience) + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new(Enabled: false) + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Act & Assert + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, new Mock().Object)); + + Assert.AreEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_MISSING_AUDIENCE_ERR_MSG, ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + + /// + /// Test to validate that user-delegated-auth with MSSQL, valid audience, and caching disabled passes validation. 
+ /// + [TestMethod] + public void ValidateUserDelegatedAuth_ValidConfiguration_Succeeds() + { + // Arrange - Set environment variables for Azure AD credentials + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, "test-client-id"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, "test-client-secret"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, "test-tenant-id"); + + try + { + DataSource dataSource = new( + DatabaseType: DatabaseType.MSSQL, + ConnectionString: "Server=test;Database=test;", + Options: null) + { + UserDelegatedAuth = new UserDelegatedAuthOptions( + Enabled: true, + Provider: "EntraId", + DatabaseAudience: "https://database.windows.net/") + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new(Enabled: false) + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Act & Assert - validation should succeed without throwing + configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, new Mock().Object); + } + finally + { + // Clean up environment variables + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, null); + } + } + + /// + /// Test to validate that user-delegated-auth with missing DAB_OBO_CLIENT_ID, DAB_OBO_TENANT_ID, + /// or DAB_OBO_CLIENT_SECRET throws an error. 
+ /// + [DataTestMethod] + [DataRow(null, "test-tenant", "test-secret", DisplayName = "Missing DAB_OBO_CLIENT_ID")] + [DataRow("test-client", null, "test-secret", DisplayName = "Missing DAB_OBO_TENANT_ID")] + [DataRow("test-client", "test-tenant", null, DisplayName = "Missing DAB_OBO_CLIENT_SECRET")] + [DataRow("", "test-tenant", "test-secret", DisplayName = "Empty DAB_OBO_CLIENT_ID")] + [DataRow("test-client", "", "test-secret", DisplayName = "Empty DAB_OBO_TENANT_ID")] + [DataRow("test-client", "test-tenant", "", DisplayName = "Empty DAB_OBO_CLIENT_SECRET")] + public void ValidateUserDelegatedAuth_MissingEnvVars_ThrowsError( + string clientId, string tenantId, string clientSecret) + { + // Arrange - Set environment variables (some may be null/empty to test validation) + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, clientId); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, tenantId); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, clientSecret); + + try + { + DataSource dataSource = new( + DatabaseType: DatabaseType.MSSQL, + ConnectionString: "Server=test;Database=test;", + Options: null) + { + UserDelegatedAuth = new UserDelegatedAuthOptions( + Enabled: true, + Provider: "EntraId", + DatabaseAudience: "https://database.windows.net/") + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new(Enabled: false) + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Act & Assert + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, new Mock().Object)); + + 
Assert.AreEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_MISSING_CREDENTIALS_ERR_MSG, ex.Message); + } + finally + { + // Clean up environment variables + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, null); + } + } + + /// + /// Test to validate that disabled user-delegated-auth does not trigger validation errors. + /// + [TestMethod] + public void ValidateUserDelegatedAuth_Disabled_SkipsValidation() + { + // Arrange - PostgreSQL with disabled user-delegated-auth should NOT fail + // Note: No environment variables set - should be fine since OBO is disabled + DataSource dataSource = new( + DatabaseType: DatabaseType.PostgreSQL, + ConnectionString: "Host=test;Database=test;", + Options: null) + { + UserDelegatedAuth = new UserDelegatedAuthOptions( + Enabled: false, // Disabled + Provider: "EntraId", + DatabaseAudience: null) // Missing audience, but should be ignored since disabled + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new(Enabled: true) // Caching enabled, but should be fine since OBO disabled + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Act & Assert - should not throw user-delegated-auth errors + try + { + configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, new Mock().Object); + } + catch (DataApiBuilderException ex) + { + // If an exception is thrown, it should NOT be one of the user-delegated-auth errors + Assert.AreNotEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_DATABASE_TYPE_ERR_MSG, ex.Message); + 
Assert.AreNotEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_MISSING_AUDIENCE_ERR_MSG, ex.Message); + Assert.AreNotEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_CACHING_ERR_MSG, ex.Message); + Assert.AreNotEqual(RuntimeConfigValidator.USER_DELEGATED_AUTH_MISSING_CREDENTIALS_ERR_MSG, ex.Message); + } + } + + /// + /// Test to validate that DAB_OBO_CLIENT_ID, DAB_OBO_TENANT_ID, and DAB_OBO_CLIENT_SECRET are required. + /// OBO authentication uses MSAL ConfidentialClientApplication with client secret for token exchange. + /// + [TestMethod] + public void ValidateUserDelegatedAuth_AllRequiredEnvVarsSet_PassesValidation() + { + // Arrange - Set all required environment variables for OBO authentication + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, "test-client-id"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, "test-tenant-id"); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, "test-client-secret"); + + try + { + DataSource dataSource = new( + DatabaseType: DatabaseType.MSSQL, + ConnectionString: "Server=test;Database=test;", + Options: null) + { + UserDelegatedAuth = new UserDelegatedAuthOptions( + Enabled: true, + Provider: "EntraId", + DatabaseAudience: "https://database.windows.net/") + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new(Enabled: false) + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Act & Assert - validation should succeed and not throw when all env vars are set + configValidator.ValidateDataSourceInConfig(runtimeConfig, fileSystem, new Mock().Object); + } + finally + { + // Clean up environment variables + 
Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR, null); + Environment.SetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR, null); + } + } + /// /// Validates that the runtime base-route is well-formatted and does not contain any reserved characeters and /// can only be configured when authentication provider is Static Web Apps. @@ -2519,3 +3049,4 @@ private static RuntimeConfigValidator InitializeRuntimeConfigValidator() } } } + diff --git a/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs b/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs new file mode 100644 index 0000000000..0f9fcf2f5c --- /dev/null +++ b/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Service.HealthCheck; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Unit tests for health check utility methods. + /// + [TestClass] + public class HealthCheckUtilitiesUnitTests + { + /// + /// Tests that connection strings are properly normalized for supported database types. 
+ /// + [TestMethod] + [DataRow( + DatabaseType.PostgreSQL, + "Host=localhost;Port=5432;Database=testdb;Username=testuser;Password=XXXX", + "Host=localhost", + "Database=testdb", + DisplayName = "PostgreSQL connection string normalization")] + [DataRow( + DatabaseType.MSSQL, + "Server=localhost;Database=testdb;User Id=testuser;Password=XXXX", + "Data Source=localhost", + "Initial Catalog=testdb", + DisplayName = "MSSQL connection string normalization")] + [DataRow( + DatabaseType.DWSQL, + "Server=localhost;Database=testdb;User Id=testuser;Password=XXXX", + "Data Source=localhost", + "Initial Catalog=testdb", + DisplayName = "DWSQL connection string normalization")] + [DataRow( + DatabaseType.MySQL, + "Server=localhost;Port=3306;Database=testdb;Uid=testuser;Pwd=XXXX", + "Server=localhost", + "Database=testdb", + DisplayName = "MySQL connection string normalization")] + public void NormalizeConnectionString_SupportedDatabases_Success( + DatabaseType dbType, + string connectionString, + string expectedServerPart, + string expectedDatabasePart) + { + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.IsNotNull(result); + Assert.IsTrue(result.Contains(expectedServerPart)); + Assert.IsTrue(result.Contains(expectedDatabasePart)); + } + + /// + /// Tests that unsupported database types return the original connection string. + /// + [TestMethod] + public void NormalizeConnectionString_UnsupportedType_ReturnsOriginal() + { + // Arrange + string connectionString = "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test"; + DatabaseType dbType = DatabaseType.CosmosDB_NoSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.AreEqual(connectionString, result); + } + + /// + /// Tests that malformed connection strings are handled gracefully. 
+ /// + [TestMethod] + [DataRow(DatabaseType.PostgreSQL, true, DisplayName = "PostgreSQL malformed string with logger")] + [DataRow(DatabaseType.MSSQL, true, DisplayName = "MSSQL malformed string with logger")] + [DataRow(DatabaseType.MySQL, false, DisplayName = "MySQL malformed string without logger")] + public void NormalizeConnectionString_MalformedString_ReturnsOriginal( + DatabaseType dbType, + bool useLogger) + { + // Arrange + string malformedConnectionString = "InvalidConnectionString;NoEquals"; + Mock? mockLogger = useLogger ? new Mock() : null; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString( + malformedConnectionString, + dbType, + mockLogger?.Object); + + // Assert + Assert.AreEqual(malformedConnectionString, result); + if (useLogger && mockLogger != null) + { + mockLogger.Verify( + x => x.Log( + LogLevel.Warning, + It.IsAny(), + It.Is((v, t) => true), + It.IsAny(), + It.Is>((v, t) => true)), + Times.Once); + } + } + + /// + /// Tests that PostgreSQL connection strings with lowercase keywords are normalized correctly. + /// This is the specific bug that was reported - lowercase 'host' was not supported. + /// + [TestMethod] + public void NormalizeConnectionString_PostgreSQL_LowercaseKeywords_Success() + { + // Arrange + string connectionString = "host=localhost;port=5432;database=mydb;username=myuser;password=XXXX"; + DatabaseType dbType = DatabaseType.PostgreSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.IsNotNull(result); + // NpgsqlConnectionStringBuilder should normalize lowercase keywords to proper format + Assert.IsTrue(result.Contains("Host=localhost") || result.Contains("host=localhost")); + Assert.IsTrue(result.Contains("Database=mydb") || result.Contains("database=mydb")); + } + + /// + /// Tests that empty connection strings are handled gracefully. 
+ /// + [TestMethod] + public void NormalizeConnectionString_EmptyString_ReturnsEmpty() + { + // Arrange + string connectionString = string.Empty; + DatabaseType dbType = DatabaseType.PostgreSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.AreEqual(string.Empty, result); + } + /// + /// Tests that GetCurrentRole returns "anonymous" when no auth headers are present. + /// + [TestMethod] + public void GetCurrentRole_NoHeaders_ReturnsAnonymous() + { + HealthCheckHelper helper = CreateHelper(); + string role = helper.GetCurrentRole(roleHeader: string.Empty, roleToken: string.Empty); + Assert.AreEqual(AuthorizationResolver.ROLE_ANONYMOUS, role); + } + + /// + /// Tests that GetCurrentRole returns "authenticated" when a bearer token is present but no role header is supplied. + /// + [TestMethod] + public void GetCurrentRole_BearerTokenOnly_ReturnsAuthenticated() + { + HealthCheckHelper helper = CreateHelper(); + string role = helper.GetCurrentRole(roleHeader: string.Empty, roleToken: "some-bearer-token"); + Assert.AreEqual(AuthorizationResolver.ROLE_AUTHENTICATED, role); + } + + /// + /// Tests that GetCurrentRole returns the explicit role value when the X-MS-API-ROLE header is provided. + /// + [TestMethod] + [DataRow("anonymous", DisplayName = "Explicit anonymous role header")] + [DataRow("authenticated", DisplayName = "Explicit authenticated role header")] + [DataRow("customrole", DisplayName = "Custom role header")] + public void GetCurrentRole_ExplicitRoleHeader_ReturnsHeaderValue(string explicitRole) + { + HealthCheckHelper helper = CreateHelper(); + string role = helper.GetCurrentRole(roleHeader: explicitRole, roleToken: string.Empty); + Assert.AreEqual(explicitRole, role); + } + + /// + /// Tests that the role header takes priority over the bearer token when both are present. 
+ /// + [TestMethod] + public void GetCurrentRole_BothHeaderAndToken_RoleHeaderWins() + { + HealthCheckHelper helper = CreateHelper(); + string role = helper.GetCurrentRole(roleHeader: "customrole", roleToken: "some-bearer-token"); + Assert.AreEqual("customrole", role); + } + + /// + /// Tests that ReadRoleHeaders correctly reads X-MS-API-ROLE from the request. + /// + [TestMethod] + public void ReadRoleHeaders_WithRoleHeader_ReturnsRoleHeader() + { + HealthCheckHelper helper = CreateHelper(); + DefaultHttpContext context = new(); + context.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER] = "myrole"; + + (string roleHeader, string roleToken) = helper.ReadRoleHeaders(context); + + Assert.AreEqual("myrole", roleHeader); + Assert.AreEqual(string.Empty, roleToken); + } + + /// + /// Tests that ReadRoleHeaders returns empty strings when no headers are present. + /// + [TestMethod] + public void ReadRoleHeaders_NoHeaders_ReturnsEmpty() + { + HealthCheckHelper helper = CreateHelper(); + DefaultHttpContext context = new(); + + (string roleHeader, string roleToken) = helper.ReadRoleHeaders(context); + + Assert.AreEqual(string.Empty, roleHeader); + Assert.AreEqual(string.Empty, roleToken); + } + + /// + /// Tests that the cached health response does not reuse a previous caller's currentRole. + /// GetCurrentRole is a pure function: same input always produces same output, + /// and different inputs (representing different callers) produce different outputs. 
+ /// + [TestMethod] + public void GetCurrentRole_CacheDoesNotLeakRole_DifferentCallersGetDifferentRoles() + { + HealthCheckHelper helper = CreateHelper(); + + // Simulate request 1 (anonymous, no headers) + string role1 = helper.GetCurrentRole(roleHeader: string.Empty, roleToken: string.Empty); + + // Simulate request 2 (authenticated, with bearer token) + string role2 = helper.GetCurrentRole(roleHeader: string.Empty, roleToken: "bearer-token"); + + // Simulate request 3 (explicit custom role) + string role3 = helper.GetCurrentRole(roleHeader: "adminrole", roleToken: string.Empty); + + Assert.AreEqual(AuthorizationResolver.ROLE_ANONYMOUS, role1); + Assert.AreEqual(AuthorizationResolver.ROLE_AUTHENTICATED, role2); + Assert.AreEqual("adminrole", role3); + } + + /// + /// Tests that parallel calls to GetCurrentRole with different roles do not bleed values across calls. + /// Validates the singleton-safe design (no shared mutable state). + /// + [TestMethod] + public async Task GetCurrentRole_ParallelRequests_NoRoleBleed() + { + HealthCheckHelper helper = CreateHelper(); + + // Run many parallel "requests" each with a unique role + int parallelCount = 50; + string[] expectedRoles = new string[parallelCount]; + string[] actualRoles = new string[parallelCount]; + + for (int i = 0; i < parallelCount; i++) + { + expectedRoles[i] = $"role-{i}"; + } + + List tasks = new(); + for (int i = 0; i < parallelCount; i++) + { + int index = i; + tasks.Add(Task.Run(() => + { + actualRoles[index] = helper.GetCurrentRole(roleHeader: expectedRoles[index], roleToken: string.Empty); + })); + } + + await Task.WhenAll(tasks); + + for (int i = 0; i < parallelCount; i++) + { + Assert.AreEqual(expectedRoles[i], actualRoles[i], $"Role bleed detected at index {i}: expected '{expectedRoles[i]}' but got '{actualRoles[i]}'"); + } + } + + private static HealthCheckHelper CreateHelper() + { + Mock> loggerMock = new(); + // HttpUtilities is not invoked by the methods under test (GetCurrentRole, 
ReadRoleHeaders), + // so passing null is safe here. + return new HealthCheckHelper(loggerMock.Object, null!); + } + } +} diff --git a/src/Service.Tests/UnitTests/McpTelemetryTests.cs b/src/Service.Tests/UnitTests/McpTelemetryTests.cs new file mode 100644 index 0000000000..9a8130f012 --- /dev/null +++ b/src/Service.Tests/UnitTests/McpTelemetryTests.cs @@ -0,0 +1,382 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Core.Telemetry; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Tests for MCP telemetry functionality. + /// + [TestClass] + public class McpTelemetryTests + { + private static ActivityListener? _activityListener; + private static readonly List _recordedActivities = new(); + + /// + /// Initialize activity listener before all tests. + /// + [ClassInitialize] + public static void ClassInitialize(TestContext context) + { + _activityListener = new ActivityListener + { + ShouldListenTo = (activitySource) => activitySource.Name == "DataApiBuilder", + Sample = (ref ActivityCreationOptions options) => ActivitySamplingResult.AllDataAndRecorded, + ActivityStarted = activity => { }, + ActivityStopped = activity => + { + _recordedActivities.Add(activity); + } + }; + ActivitySource.AddActivityListener(_activityListener); + } + + /// + /// Cleanup activity listener after all tests. 
+ /// + [ClassCleanup] + public static void ClassCleanup() + { + _activityListener?.Dispose(); + } + + /// + /// Clear recorded activities before each test. + /// + [TestInitialize] + public void TestInitialize() + { + _recordedActivities.Clear(); + } + + #region Helpers + + /// + /// Creates and starts a new MCP tool execution activity, asserting it was created. + /// + private static Activity CreateActivity() + { + Activity? activity = TelemetryTracesHelper.DABActivitySource.StartActivity("mcp.tool.execute"); + Assert.IsNotNull(activity, "Activity should be created"); + return activity; + } + + /// + /// Stops the activity and returns the first recorded activity, asserting it was captured. + /// + private static Activity StopAndGetRecordedActivity(Activity activity) + { + activity.Stop(); + Activity? recorded = _recordedActivities.FirstOrDefault(); + Assert.IsNotNull(recorded, "Activity should be recorded"); + return recorded; + } + + /// + /// Builds a minimal service provider for tests that don't need real services. + /// + private static IServiceProvider CreateServiceProvider() + { + return new ServiceCollection().BuildServiceProvider(); + } + + /// + /// Creates a CallToolResult with the given text and error state. + /// + private static CallToolResult CreateToolResult(string text = "result", bool isError = false) + { + return new CallToolResult + { + Content = new List { new TextContentBlock { Text = text } }, + IsError = isError + }; + } + + /// + /// Creates an exception instance from a type name string, for use with DataRow tests. 
+ /// + private static Exception CreateExceptionByTypeName(string typeName) + { + return typeName switch + { + nameof(OperationCanceledException) => new OperationCanceledException(), + nameof(UnauthorizedAccessException) => new UnauthorizedAccessException(), + nameof(ArgumentException) => new ArgumentException(), + nameof(InvalidOperationException) => new InvalidOperationException(), + _ => new Exception() + }; + } + + #endregion + + #region TrackMcpToolExecutionStarted + + /// + /// Test that TrackMcpToolExecutionStarted sets the expected tags for various input combinations, + /// including when optional parameters are null. + /// + [DataTestMethod] + [DataRow("read_records", "books", "read", null, DisplayName = "Sets entity, operation; no procedure")] + [DataRow("custom_proc", "CustomEntity", "execute", "dbo.CustomProc", DisplayName = "Custom tool with all tags including db.procedure")] + [DataRow("describe_entities", null, "describe", null, DisplayName = "Describe tool with null entity")] + [DataRow("custom_tool", "MyEntity", "execute", "schema.MyStoredProc", DisplayName = "Sets all four tags")] + public void TrackMcpToolExecutionStarted_SetsExpectedTags( + string toolName, string? entityName, string? operation, string? 
dbProcedure) + { + // Arrange & Act + using Activity activity = CreateActivity(); + activity.TrackMcpToolExecutionStarted( + toolName: toolName, + entityName: entityName, + operation: operation, + dbProcedure: dbProcedure); + + Activity recorded = StopAndGetRecordedActivity(activity); + + // Assert — tool name is always set + Assert.AreEqual(toolName, recorded.GetTagItem("mcp.tool.name")); + + // Optional tags: present only when supplied + Assert.AreEqual(entityName, recorded.GetTagItem("dab.entity")); + Assert.AreEqual(operation, recorded.GetTagItem("dab.operation")); + Assert.AreEqual(dbProcedure, recorded.GetTagItem("db.procedure")); + } + + #endregion + + #region TrackMcpToolExecutionFinished + + /// + /// Test that TrackMcpToolExecutionFinished sets status to OK. + /// + [TestMethod] + public void TrackMcpToolExecutionFinished_SetsStatusToOk() + { + using Activity activity = CreateActivity(); + activity.TrackMcpToolExecutionStarted(toolName: "read_records"); + activity.TrackMcpToolExecutionFinished(); + + Activity recorded = StopAndGetRecordedActivity(activity); + Assert.AreEqual(ActivityStatusCode.Ok, recorded.Status); + } + + /// + /// Test that TrackMcpToolExecutionFinishedWithException records exception and sets error status. 
+ /// + [TestMethod] + public void TrackMcpToolExecutionFinishedWithException_RecordsExceptionAndSetsErrorStatus() + { + using Activity activity = CreateActivity(); + activity.TrackMcpToolExecutionStarted(toolName: "read_records"); + + Exception testException = new InvalidOperationException("Test exception"); + activity.TrackMcpToolExecutionFinishedWithException(testException, errorCode: McpTelemetryErrorCodes.EXECUTION_FAILED); + + Activity recorded = StopAndGetRecordedActivity(activity); + Assert.AreEqual(ActivityStatusCode.Error, recorded.Status); + Assert.AreEqual("Test exception", recorded.StatusDescription); + Assert.AreEqual("InvalidOperationException", recorded.GetTagItem("error.type")); + Assert.AreEqual("Test exception", recorded.GetTagItem("error.message")); + Assert.AreEqual(McpTelemetryErrorCodes.EXECUTION_FAILED, recorded.GetTagItem("error.code")); + + ActivityEvent? exceptionEvent = recorded.Events.FirstOrDefault(e => e.Name == "exception"); + Assert.IsNotNull(exceptionEvent, "Exception event should be recorded"); + } + + #endregion + + #region InferOperationFromTool + + /// + /// Test that InferOperationFromTool returns the correct operation for built-in and custom tools. + /// Built-in tools are mapped by name; custom tools always return "execute". 
+ /// + [DataTestMethod] + // Built-in DML tool names mapped to operations + [DataRow(ToolType.BuiltIn, "read_records", "read", DisplayName = "Built-in: read_records -> read")] + [DataRow(ToolType.BuiltIn, "create_record", "create", DisplayName = "Built-in: create_record -> create")] + [DataRow(ToolType.BuiltIn, "update_record", "update", DisplayName = "Built-in: update_record -> update")] + [DataRow(ToolType.BuiltIn, "delete_record", "delete", DisplayName = "Built-in: delete_record -> delete")] + [DataRow(ToolType.BuiltIn, "describe_entities", "describe", DisplayName = "Built-in: describe_entities -> describe")] + [DataRow(ToolType.BuiltIn, "execute_entity", "execute", DisplayName = "Built-in: execute_entity -> execute")] + [DataRow(ToolType.BuiltIn, "unknown_builtin", "execute", DisplayName = "Built-in: unknown -> execute (fallback)")] + // Custom tools always return "execute" + [DataRow(ToolType.Custom, "get_book", "execute", DisplayName = "Custom: get_book -> execute (stored proc)")] + [DataRow(ToolType.Custom, "read_users", "execute", DisplayName = "Custom: read_users -> execute (ignore name)")] + [DataRow(ToolType.Custom, "custom_proc", "execute", DisplayName = "Custom: custom_proc -> execute")] + public void InferOperationFromTool_ReturnsCorrectOperation(ToolType toolType, string toolName, string expectedOperation) + { + IMcpTool tool = new MockMcpTool(CreateToolResult(), toolType); + Assert.AreEqual(expectedOperation, McpTelemetryHelper.InferOperationFromTool(tool, toolName)); + } + + #endregion + + #region MapExceptionToErrorCode + + /// + /// Test that MapExceptionToErrorCode returns the correct error code for each exception type. 
+ /// + [DataTestMethod] + [DataRow("OperationCanceledException", McpTelemetryErrorCodes.OPERATION_CANCELLED)] + [DataRow("UnauthorizedAccessException", McpTelemetryErrorCodes.AUTHORIZATION_FAILED)] + [DataRow("ArgumentException", McpTelemetryErrorCodes.INVALID_REQUEST)] + [DataRow("InvalidOperationException", McpTelemetryErrorCodes.EXECUTION_FAILED)] + [DataRow("Exception", McpTelemetryErrorCodes.EXECUTION_FAILED)] + public void MapExceptionToErrorCode_ReturnsCorrectCode(string exceptionTypeName, string expectedErrorCode) + { + Exception ex = CreateExceptionByTypeName(exceptionTypeName); + Assert.AreEqual(expectedErrorCode, McpTelemetryHelper.MapExceptionToErrorCode(ex)); + } + + #endregion + + #region ExecuteWithTelemetryAsync + + /// + /// Test that ExecuteWithTelemetryAsync sets Ok status and correct operation for all built-in DML tools. + /// + [DataTestMethod] + [DataRow("read_records", "read", DisplayName = "read_records -> read operation")] + [DataRow("create_record", "create", DisplayName = "create_record -> create operation")] + [DataRow("update_record", "update", DisplayName = "update_record -> update operation")] + [DataRow("delete_record", "delete", DisplayName = "delete_record -> delete operation")] + [DataRow("describe_entities", "describe", DisplayName = "describe_entities -> describe operation")] + [DataRow("execute_entity", "execute", DisplayName = "execute_entity -> execute operation")] + public async Task ExecuteWithTelemetryAsync_SetsOkStatusAndCorrectOperation_ForBuiltInTools( + string toolName, string expectedOperation) + { + CallToolResult expectedResult = CreateToolResult("success"); + IMcpTool tool = new MockMcpTool(expectedResult, ToolType.BuiltIn); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, toolName, arguments: null, CreateServiceProvider(), CancellationToken.None); + + Assert.AreSame(expectedResult, result); + Activity recorded = _recordedActivities.First(); + 
Assert.AreEqual(ActivityStatusCode.Ok, recorded.Status); + Assert.AreEqual(toolName, recorded.GetTagItem("mcp.tool.name")); + Assert.AreEqual(expectedOperation, recorded.GetTagItem("dab.operation")); + } + + /// + /// Test that ExecuteWithTelemetryAsync always sets operation to "execute" for custom tools (stored procedures). + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_SetsExecuteOperation_ForCustomTools() + { + CallToolResult expectedResult = CreateToolResult("success"); + IMcpTool tool = new MockMcpTool(expectedResult, ToolType.Custom); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "get_book", arguments: null, CreateServiceProvider(), CancellationToken.None); + + Assert.AreSame(expectedResult, result); + Activity recorded = _recordedActivities.First(); + Assert.AreEqual(ActivityStatusCode.Ok, recorded.Status); + Assert.AreEqual("get_book", recorded.GetTagItem("mcp.tool.name")); + Assert.AreEqual("execute", recorded.GetTagItem("dab.operation")); + } + + /// + /// Test that ExecuteWithTelemetryAsync sets Error status when tool returns IsError=true. + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_SetsErrorStatus_WhenToolReturnsIsError() + { + CallToolResult errorResult = CreateToolResult("error occurred", isError: true); + IMcpTool tool = new MockMcpTool(errorResult, ToolType.BuiltIn); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "create_record", arguments: null, CreateServiceProvider(), CancellationToken.None); + + Assert.AreSame(errorResult, result); + Activity recorded = _recordedActivities.First(); + Assert.AreEqual(ActivityStatusCode.Error, recorded.Status); + Assert.AreEqual(true, recorded.GetTagItem("mcp.tool.error")); + } + + /// + /// Test that ExecuteWithTelemetryAsync records exception and re-throws when tool throws. 
+ /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_RecordsExceptionAndRethrows_WhenToolThrows() + { + InvalidOperationException expectedException = new("tool exploded"); + IMcpTool tool = new MockMcpTool(expectedException, ToolType.BuiltIn); + + InvalidOperationException thrownEx = await Assert.ThrowsExceptionAsync( + () => McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "delete_record", arguments: null, CreateServiceProvider(), CancellationToken.None)); + + Assert.AreEqual("tool exploded", thrownEx.Message); + + Activity recorded = _recordedActivities.First(); + Assert.AreEqual(ActivityStatusCode.Error, recorded.Status); + Assert.AreEqual("InvalidOperationException", recorded.GetTagItem("error.type")); + Assert.AreEqual(McpTelemetryErrorCodes.EXECUTION_FAILED, recorded.GetTagItem("error.code")); + + ActivityEvent? exceptionEvent = recorded.Events.FirstOrDefault(e => e.Name == "exception"); + Assert.IsNotNull(exceptionEvent, "Exception event should be recorded"); + } + + #endregion + + #region Test Mocks + + /// + /// A minimal mock IMcpTool for testing ExecuteWithTelemetryAsync. + /// Returns a predetermined result or throws a predetermined exception. + /// + private class MockMcpTool : IMcpTool + { + private readonly CallToolResult? _result; + private readonly Exception? _exception; + + public MockMcpTool(CallToolResult result, ToolType toolType = ToolType.BuiltIn) + { + _result = result; + ToolType = toolType; + } + + public MockMcpTool(Exception exception, ToolType toolType = ToolType.BuiltIn) + { + _exception = exception; + ToolType = toolType; + } + + public ToolType ToolType { get; } + + public Tool GetToolMetadata() => new() { Name = "mock_tool", Description = "Mock tool for testing" }; + + public Task ExecuteAsync(JsonDocument? 
arguments, IServiceProvider serviceProvider, CancellationToken cancellationToken = default) + { + if (_exception != null) + { + throw _exception; + } + + return Task.FromResult(_result!); + } + } + + #endregion + } +} diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs new file mode 100644 index 0000000000..4da3266271 --- /dev/null +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Core.Parsers; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Test class for RequestParser utility methods. + /// Specifically tests the ExtractRawQueryParameter method which preserves + /// URL encoding for special characters in query parameters. + /// + [TestClass] + public class RequestParserUnitTests + { + /// + /// Tests that ExtractRawQueryParameter correctly extracts URL-encoded + /// parameter values, preserving special characters like ampersand (&). 
+ /// + [DataTestMethod] + [DataRow("?$filter=region%20eq%20%27filter%20%26%20test%27", "$filter", "region%20eq%20%27filter%20%26%20test%27", DisplayName = "Extract filter with encoded ampersand (&)")] + [DataRow("?$filter=title%20eq%20%27A%20%26%20B%27&$select=id", "$filter", "title%20eq%20%27A%20%26%20B%27", DisplayName = "Extract filter with ampersand and other params")] + [DataRow("?$select=id&$filter=name%20eq%20%27test%27", "$filter", "name%20eq%20%27test%27", DisplayName = "Extract filter when not first parameter")] + [DataRow("?$orderby=name%20asc", "$orderby", "name%20asc", DisplayName = "Extract orderby parameter")] + [DataRow("?param1=value1¶m2=value%26with%26ampersands", "param2", "value%26with%26ampersands", DisplayName = "Extract parameter with multiple ampersands")] + [DataRow("$filter=title%20eq%20%27test%27", "$filter", "title%20eq%20%27test%27", DisplayName = "Extract without leading question mark")] + [DataRow("?$filter=", "$filter", "", DisplayName = "Extract empty filter value")] + [DataRow("?$filter=name%20eq%20%27test%3D123%27", "$filter", "name%20eq%20%27test%3D123%27", DisplayName = "Extract filter with encoded equals sign (=)")] + [DataRow("?$filter=url%20eq%20%27http%3A%2F%2Fexample.com%3Fkey%3Dvalue%27", "$filter", "url%20eq%20%27http%3A%2F%2Fexample.com%3Fkey%3Dvalue%27", DisplayName = "Extract filter with encoded URL (: / ?)")] + [DataRow("?$filter=text%20eq%20%27A%2BB%27", "$filter", "text%20eq%20%27A%2BB%27", DisplayName = "Extract filter with encoded plus sign (+)")] + [DataRow("?$filter=value%20eq%20%2750%25%27", "$filter", "value%20eq%20%2750%25%27", DisplayName = "Extract filter with encoded percent sign (%)")] + [DataRow("?$filter=tag%20eq%20%27%23hashtag%27", "$filter", "tag%20eq%20%27%23hashtag%27", DisplayName = "Extract filter with encoded hash (#)")] + [DataRow("?$filter=expr%20eq%20%27a%3Cb%3Ed%27", "$filter", "expr%20eq%20%27a%3Cb%3Ed%27", DisplayName = "Extract filter with encoded less-than and greater-than (< >)")] + 
public void ExtractRawQueryParameter_PreservesEncoding(string queryString, string parameterName, string expectedValue) + { + // Call the internal method directly (no reflection needed) + string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + + Assert.AreEqual(expectedValue, result, + $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); + } + + /// + /// Tests that ExtractRawQueryParameter returns null when parameter is not found. + /// + [DataTestMethod] + [DataRow("?$filter=test", "$orderby", DisplayName = "Parameter not in query string")] + [DataRow("", "$filter", DisplayName = "Empty query string")] + [DataRow(null, "$filter", DisplayName = "Null query string")] + [DataRow("?otherParam=value", "$filter", DisplayName = "Different parameter")] + public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string? queryString, string parameterName) + { + // Call the internal method directly (no reflection needed) + string? 
result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + + Assert.IsNull(result, + $"Expected null but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); + } + + /// + /// Tests that ExtractRawQueryParameter handles edge cases correctly: + /// - Duplicate parameters (returns first occurrence) + /// - Case-insensitive parameter name matching + /// - Malformed query strings with unencoded ampersands + /// + [DataTestMethod] + [DataRow("?$filter=value&$filter=anothervalue", "$filter", "value", DisplayName = "Multiple same parameters - returns first")] + [DataRow("?$FILTER=value", "$filter", "value", DisplayName = "Case insensitive parameter matching")] + [DataRow("?param=value1&value2", "param", "value1", DisplayName = "Value with unencoded ampersand after parameter")] + public void ExtractRawQueryParameter_HandlesEdgeCases(string queryString, string parameterName, string expectedValue) + { + // Call the internal method directly (no reflection needed) + string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + + Assert.AreEqual(expectedValue, result, + $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); + } + } +} diff --git a/src/Service.Tests/UnitTests/RestServiceUnitTests.cs b/src/Service.Tests/UnitTests/RestServiceUnitTests.cs index 1fa1a276ad..3f296b4403 100644 --- a/src/Service.Tests/UnitTests/RestServiceUnitTests.cs +++ b/src/Service.Tests/UnitTests/RestServiceUnitTests.cs @@ -100,6 +100,73 @@ public void ErrorForInvalidRouteAndPathToParseTest(string route, #endregion + #region Sub-directory Path Routing Tests + + /// + /// Tests that sub-directory entity paths are correctly resolved. 
+ /// + [DataTestMethod] + [DataRow("api/shopping-cart/item", "/api", "shopping-cart/item", "ShoppingCartItem", "")] + [DataRow("api/shopping-cart/item/id/123", "/api", "shopping-cart/item", "ShoppingCartItem", "id/123")] + [DataRow("api/invoice/item/categoryid/1/pieceid/2", "/api", "invoice/item", "InvoiceItem", "categoryid/1/pieceid/2")] + public void SubDirectoryPathRoutingTest( + string route, + string restPath, + string entityPath, + string expectedEntityName, + string expectedPrimaryKeyRoute) + { + InitializeTestWithEntityPath(restPath, entityPath, expectedEntityName); + string routeAfterPathBase = _restService.GetRouteAfterPathBase(route); + (string actualEntityName, string actualPrimaryKeyRoute) = + _restService.GetEntityNameAndPrimaryKeyRouteFromRoute(routeAfterPathBase); + Assert.AreEqual(expectedEntityName, actualEntityName); + Assert.AreEqual(expectedPrimaryKeyRoute, actualPrimaryKeyRoute); + } + + /// + /// Tests longest-prefix matching: when both "cart" and "cart/item" are valid entity paths, + /// a request to "/cart/item/id/123" should match "cart/item" (longest match wins). + /// + [TestMethod] + public void LongestPrefixMatchingTest() + { + InitializeTestWithMultipleEntityPaths("/api", new Dictionary + { + { "cart", "CartEntity" }, + { "cart/item", "CartItemEntity" } + }); + + string routeAfterPathBase = _restService.GetRouteAfterPathBase("api/cart/item/id/123"); + (string actualEntityName, string actualPrimaryKeyRoute) = + _restService.GetEntityNameAndPrimaryKeyRouteFromRoute(routeAfterPathBase); + + // Should match "cart/item" (longest), not "cart" (shortest) + Assert.AreEqual("CartItemEntity", actualEntityName); + Assert.AreEqual("id/123", actualPrimaryKeyRoute); + } + + /// + /// Tests that when only shorter path exists, it matches correctly. 
+ /// + [TestMethod] + public void SinglePathMatchingTest() + { + InitializeTestWithMultipleEntityPaths("/api", new Dictionary + { + { "cart", "CartEntity" } + }); + + string routeAfterPathBase = _restService.GetRouteAfterPathBase("api/cart/id/123"); + (string actualEntityName, string actualPrimaryKeyRoute) = + _restService.GetEntityNameAndPrimaryKeyRouteFromRoute(routeAfterPathBase); + + Assert.AreEqual("CartEntity", actualEntityName); + Assert.AreEqual("id/123", actualPrimaryKeyRoute); + } + + #endregion + #region Helper Functions /// @@ -108,6 +175,47 @@ public void ErrorForInvalidRouteAndPathToParseTest(string route, /// /// path to return from mocked config. public static void InitializeTest(string restRoutePrefix, string entityName) + { + InitializeTestWithEntityPaths(restRoutePrefix, new Dictionary { { entityName, entityName } }); + } + + /// + /// Needed for the callback that is required + /// to make use of out parameter with mocking. + /// Without use of delegate the out param will + /// not be populated with the correct value. + /// This delegate is for the callback used + /// with the mocked MetadataProvider. + /// + /// The entity path. + /// Name of entity. + delegate void metaDataCallback(string entityPath, out string entity); + + /// + /// Initializes test with a sub-directory entity path. + /// + /// REST path prefix (e.g., "/api"). + /// Entity path with sub-directories (e.g., "shopping-cart/item"). + /// Name of the entity. + public static void InitializeTestWithEntityPath(string restRoutePrefix, string entityPath, string entityName) + { + InitializeTestWithEntityPaths(restRoutePrefix, new Dictionary { { entityPath, entityName } }); + } + + /// + /// Initializes test with multiple entity paths for testing overlapping path scenarios. + /// + /// REST path prefix (e.g., "/api"). + /// Dictionary mapping entity paths to entity names. 
+ public static void InitializeTestWithMultipleEntityPaths(string restRoutePrefix, Dictionary entityPaths) + { + InitializeTestWithEntityPaths(restRoutePrefix, entityPaths); + } + + /// + /// Core helper to initialize REST Service with specified entity path mappings. + /// + private static void InitializeTestWithEntityPaths(string restRoutePrefix, Dictionary entityPaths) { RuntimeConfig mockConfig = new( Schema: "", @@ -147,7 +255,10 @@ public static void InitializeTest(string restRoutePrefix, string entityName) queryManagerFactory.Setup(x => x.GetQueryExecutor(It.IsAny())).Returns(queryExecutor); RuntimeConfig loadedConfig = provider.GetConfig(); - loadedConfig.TryAddEntityPathNameToEntityName(entityName, entityName); + foreach (KeyValuePair mapping in entityPaths) + { + loadedConfig.TryAddEntityPathNameToEntityName(mapping.Key, mapping.Value); + } Mock sqlMetadataProvider = new(); Mock authorizationService = new(); @@ -195,18 +306,6 @@ public static void InitializeTest(string restRoutePrefix, string entityName) provider, requestValidator); } - - /// - /// Needed for the callback that is required - /// to make use of out parameter with mocking. - /// Without use of delegate the out param will - /// not be populated with the correct value. - /// This delegate is for the callback used - /// with the mocked MetadataProvider. - /// - /// The entity path. - /// Name of entity. 
- delegate void metaDataCallback(string entityPath, out string entity); #endregion } } diff --git a/src/Service.Tests/UnitTests/SqlMetadataProviderUnitTests.cs b/src/Service.Tests/UnitTests/SqlMetadataProviderUnitTests.cs index 8b4ed68f60..4c5782b4ca 100644 --- a/src/Service.Tests/UnitTests/SqlMetadataProviderUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlMetadataProviderUnitTests.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Data.Common; using System.IO; +using System.IO.Abstractions; using System.Net; using System.Text.Json.Nodes; using System.Threading.Tasks; @@ -129,8 +130,13 @@ public void CheckTablePrefix(string schemaName, string tableName, string expecte queryManagerFactory.Setup(x => x.GetQueryBuilder(It.IsAny())).Returns(queryBuilder); queryManagerFactory.Setup(x => x.GetQueryExecutor(It.IsAny())).Returns(queryExecutor.Object); + IFileSystem fileSystem = new FileSystem(); + ILogger validatorLogger = new Mock>().Object; + RuntimeConfigValidator runtimeConfigValidator = new(runtimeConfigProvider, fileSystem, validatorLogger); + SqlMetadataProvider provider = new MsSqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName); @@ -215,11 +221,15 @@ private static async Task CheckExceptionForBadConnectionStringHelperAsync(string queryManagerFactory.Setup(x => x.GetQueryBuilder(It.IsAny())).Returns(_queryBuilder); queryManagerFactory.Setup(x => x.GetQueryExecutor(It.IsAny())).Returns(_queryExecutor); + IFileSystem fileSystem = new FileSystem(); + Mock> loggerValidator = new(); + RuntimeConfigValidator runtimeConfigValidator = new(runtimeConfigProvider, fileSystem, loggerValidator.Object); + ISqlMetadataProvider sqlMetadataProvider = databaseType switch { - TestCategory.MSSQL => new MsSqlMetadataProvider(runtimeConfigProvider, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), - TestCategory.MYSQL => new MySqlMetadataProvider(runtimeConfigProvider, 
queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), - TestCategory.POSTGRESQL => new PostgreSqlMetadataProvider(runtimeConfigProvider, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), + TestCategory.MSSQL => new MsSqlMetadataProvider(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), + TestCategory.MYSQL => new MySqlMetadataProvider(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), + TestCategory.POSTGRESQL => new PostgreSqlMetadataProvider(runtimeConfigProvider, runtimeConfigValidator, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName), _ => throw new ArgumentException($"Invalid database type: {databaseType}") }; @@ -480,8 +490,13 @@ public async Task ValidateExceptionForInvalidResultFieldNames(string invalidFiel queryManagerFactory.Setup(x => x.GetQueryBuilder(It.IsAny())).Returns(_queryBuilder); queryManagerFactory.Setup(x => x.GetQueryExecutor(It.IsAny())).Returns(mockQueryExecutor.Object); + IFileSystem fileSystem = new FileSystem(); + Mock> loggerValidator = new(); + RuntimeConfigValidator runtimeConfigValidator = new(runtimeConfigProvider, fileSystem, loggerValidator.Object); + ISqlMetadataProvider sqlMetadataProvider = new MsSqlMetadataProvider( runtimeConfigProvider, + runtimeConfigValidator, queryManagerFactory.Object, sqlMetadataLogger, dataSourceName); @@ -588,5 +603,66 @@ private static async Task SetupTestFixtureAndInferMetadata() await ResetDbStateAsync(); await _sqlMetadataProvider.InitializeAsync(); } + + /// + /// Ensures that the query that returns the tables that will be generated + /// into entities from the autoentities configuration returns the expected result. 
+ /// + [DataTestMethod, TestCategory(TestCategory.MSSQL)] + [DataRow(new string[] { "dbo.%book%" }, new string[] { }, "{schema}.{object}.books", new string[] { "book" }, "")] + [DataRow(new string[] { "dbo.%publish%" }, new string[] { }, "{schema}.{object}", new string[] { "publish" }, "")] + [DataRow(new string[] { "dbo.%book%" }, new string[] { "dbo.%books%" }, "{schema}_{object}_exclude_books", new string[] { "book" }, "books")] + [DataRow(new string[] { "dbo.%book%", "dbo.%publish%" }, new string[] { }, "{object}", new string[] { "book", "publish" }, "")] + [DataRow(new string[] { }, new string[] { "dbo.%book%" }, "{object}s", new string[] { "" }, "book")] + public async Task CheckAutoentitiesQuery(string[] include, string[] exclude, string name, string[] includeObject, string excludeObject) + { + // Arrange + DatabaseEngine = TestCategory.MSSQL; + TestHelper.SetupDatabaseEnvironment(DatabaseEngine); + RuntimeConfig runtimeConfig = SqlTestHelper.SetupRuntimeConfig(); + Autoentity autoentity = new(new AutoentityPatterns(include, exclude, name), null, null); + Dictionary dictAutoentity = new() + { + { "autoentity", autoentity } + }; + RuntimeConfig configWithAutoentity = runtimeConfig with + { + Autoentities = new RuntimeAutoentities(dictAutoentity) + }; + RuntimeConfigProvider runtimeConfigProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(configWithAutoentity); + SetUpSQLMetadataProvider(runtimeConfigProvider); + + // Act + MsSqlMetadataProvider metadataProvider = (MsSqlMetadataProvider)_sqlMetadataProvider; + JsonArray resultArray = await metadataProvider.QueryAutoentitiesAsync(autoentity); + + // Assert + Assert.IsNotNull(resultArray); + foreach (JsonObject resultObject in resultArray) + { + bool includedObjectExists = false; + foreach (string included in includeObject) + { + if (resultObject["object"].ToString().Contains(included)) + { + includedObjectExists = true; + Assert.AreNotEqual(name, resultObject["entity_name"].ToString(), "Name returned 
by query should not include {schema} or {object}."); + if (include.Length > 0) + { + Assert.AreEqual(expected: "dbo", actual: resultObject["schema"].ToString(), "Query does not return expected schema."); + } + + if (exclude.Length > 0) + { + Assert.IsTrue(!resultObject["object"].ToString().Contains(excludeObject), "Query returns pattern that should be excluded."); + } + } + } + + Assert.IsTrue(includedObjectExists, "Query does not return expected object."); + } + + TestHelper.UnsetAllDABEnvironmentVariables(); + } } } diff --git a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs index b3782950f9..bd1cd28b88 100644 --- a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs @@ -174,7 +174,7 @@ public async Task TestRetryPolicyExhaustingMaxAttempts() Mock httpContextAccessor = new(); DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(provider); Mock queryExecutor - = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, null); + = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, null, null); queryExecutor.Setup(x => x.ConnectionStringBuilders).Returns(new Dictionary()); @@ -283,8 +283,9 @@ public async Task TestRetryPolicySuccessfullyExecutingQueryAfterNAttempts() Mock httpContextAccessor = new(); DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(provider); EventHandler handler = null; + IOboTokenProvider oboTokenProvider = null; Mock queryExecutor - = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, handler); + = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, handler, oboTokenProvider); queryExecutor.Setup(x => x.ConnectionStringBuilders).Returns(new Dictionary()); @@ -368,7 +369,7 @@ public async Task TestHttpContextIsPopulatedWithDbExecutionTime() 
httpContextAccessor.Setup(x => x.HttpContext).Returns(context); DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(provider); Mock queryExecutor - = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, null); + = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, null, null); queryExecutor.Setup(x => x.ConnectionStringBuilders).Returns(new Dictionary()); @@ -419,8 +420,9 @@ public void TestInfoMessageHandlerIsAdded() Mock httpContextAccessor = new(); DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(provider); EventHandler handler = null; + IOboTokenProvider oboTokenProvider = null; Mock queryExecutor - = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, handler); + = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object, handler, oboTokenProvider); queryExecutor.Setup(x => x.ConnectionStringBuilders).Returns(new Dictionary()); @@ -669,6 +671,348 @@ public void ValidateStreamingLogicForEmptyCellsAsync() Assert.AreEqual(availableSize, (int)runtimeConfig.MaxResponseSizeMB() * 1024 * 1024); } + #region Per-User Connection Pooling Tests + + /// + /// Creates MsSqlQueryExecutor with the specified configuration for per-user connection pooling tests. + /// + /// The connection string to use. + /// Whether to enable user-delegated-auth (OBO). + /// The HttpContextAccessor mock to use. + /// A tuple containing the query executor and runtime config provider. + private static (MsSqlQueryExecutor QueryExecutor, RuntimeConfigProvider Provider) CreateQueryExecutorForPoolingTest( + string connectionString, + bool enableObo, + Mock httpContextAccessor) + { + DataSource dataSource = new( + DatabaseType: DatabaseType.MSSQL, + ConnectionString: connectionString, + Options: null) + { + UserDelegatedAuth = enableObo + ? 
new UserDelegatedAuthOptions( + Enabled: true, + Provider: "EntraId", + DatabaseAudience: "https://database.windows.net") + : null + }; + + RuntimeConfig mockConfig = new( + Schema: "", + DataSource: dataSource, + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null) + ), + Entities: new(new Dictionary())); + + MockFileSystem fileSystem = new(); + fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(mockConfig.ToJson())); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + + Mock>> queryExecutorLogger = new(); + DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(provider); + + MsSqlQueryExecutor queryExecutor = new(provider, dbExceptionParser, queryExecutorLogger.Object, httpContextAccessor.Object); + return (queryExecutor, provider); + } + + /// + /// Creates an HttpContextAccessor mock with the specified user claims. + /// + /// The issuer claim value, or empty string for no context. + /// The oid claim value, or empty string for no context. + /// A configured HttpContextAccessor mock. 
+ private static Mock CreateHttpContextAccessorWithClaims(string issuer, string objectId) + { + Mock httpContextAccessor = new(); + + if (string.IsNullOrEmpty(issuer) && string.IsNullOrEmpty(objectId)) + { + httpContextAccessor.Setup(x => x.HttpContext).Returns(value: null); + } + else + { + DefaultHttpContext context = new(); + System.Security.Claims.ClaimsIdentity identity = new("TestAuth"); + if (!string.IsNullOrEmpty(issuer)) + { + identity.AddClaim(new System.Security.Claims.Claim("iss", issuer)); + } + + if (!string.IsNullOrEmpty(objectId)) + { + identity.AddClaim(new System.Security.Claims.Claim("oid", objectId)); + } + + context.User = new System.Security.Claims.ClaimsPrincipal(identity); + httpContextAccessor.Setup(x => x.HttpContext).Returns(context); + } + + return httpContextAccessor; + } + + /// + /// Test that the Pooling property from the connection string is never modified by DAB, + /// regardless of whether OBO is enabled or disabled. If Pooling=true, it stays true. + /// If Pooling=false, it stays false. DAB respects the user's explicit configuration. 
+ /// + [DataTestMethod, TestCategory(TestCategory.MSSQL)] + [DataRow(true, true, DisplayName = "OBO enabled, Pooling=true stays true")] + [DataRow(true, false, DisplayName = "OBO enabled, Pooling=false stays false")] + [DataRow(false, true, DisplayName = "OBO disabled, Pooling=true stays true")] + [DataRow(false, false, DisplayName = "OBO disabled, Pooling=false stays false")] + public void TestPoolingPropertyIsNeverModified(bool enableObo, bool poolingValue) + { + // Arrange + Mock httpContextAccessor = new(); + string connectionString = $"Server=localhost;Database=test;Pooling={poolingValue};"; + + // Act + (MsSqlQueryExecutor queryExecutor, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: connectionString, + enableObo: enableObo, + httpContextAccessor: httpContextAccessor); + + SqlConnectionStringBuilder connBuilder = new( + queryExecutor.ConnectionStringBuilders[provider.GetConfig().DefaultDataSourceName].ConnectionString); + + // Assert - Pooling property should be unchanged from the original connection string + Assert.AreEqual(poolingValue, connBuilder.Pooling, + $"Pooling={poolingValue} should remain unchanged when OBO is {(enableObo ? "enabled" : "disabled")}"); + } + + /// + /// Test that when OBO is enabled and user claims are present, CreateConnection returns + /// a connection string with a user-specific Application Name containing the pool hash. 
+ /// + [TestMethod, TestCategory(TestCategory.MSSQL)] + public void TestOboWithUserClaims_ConnectionStringHasUserSpecificAppName() + { + // Arrange & Act + Mock httpContextAccessor = CreateHttpContextAccessorWithClaims( + issuer: "https://login.microsoftonline.com/tenant-id/v2.0", + objectId: "user-object-id-12345"); + + (MsSqlQueryExecutor queryExecutor, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: "Server=localhost;Database=test;Application Name=TestApp;", + enableObo: true, + httpContextAccessor: httpContextAccessor); + + SqlConnection conn = queryExecutor.CreateConnection(provider.GetConfig().DefaultDataSourceName); + SqlConnectionStringBuilder connBuilder = new(conn.ConnectionString); + + // Assert - Application Name should have hash prefix followed by the base name + // Format: {hash}|{user-custom-appname} + // Hash is 16 bytes truncated SHA256, Base64-encoded to ~22 chars (16 bytes * 4/3 = 21.3) + // Hash is placed first to ensure it's never truncated if app name exceeds 128 chars + Assert.IsTrue(connBuilder.ApplicationName.Contains("|"), + $"Application Name should contain '|' separator but was '{connBuilder.ApplicationName}'"); + Assert.IsTrue(connBuilder.ApplicationName.Contains("TestApp"), + $"Application Name should contain 'TestApp' but was '{connBuilder.ApplicationName}'"); + // Hash should be at the start (before the | separator) + // 16 bytes Base64-encoded (without padding) = ~22 characters + string hashPart = connBuilder.ApplicationName.Split('|')[0]; + Assert.IsTrue(hashPart.Length >= 20 && hashPart.Length <= 25, + $"Hash prefix should be ~22 chars (16 bytes Base64) but was {hashPart.Length} chars: '{hashPart}'"); + Assert.IsTrue(connBuilder.Pooling, "Pooling should be enabled"); + } + + /// + /// Test that when the base Application Name + hash prefix exceeds 128 characters, + /// the base app name is truncated (not the hash) to fit within SQL Server's limit. 
+ /// This verifies the hash-first format ensures pool isolation even with long app names. + /// + [TestMethod, TestCategory(TestCategory.MSSQL)] + public void TestOboWithLongAppName_TruncatesToFitWithinLimit() + { + // Arrange - Create an Application Name that would exceed 128 chars when hash is added + // Hash prefix is ~22 chars + "|" = 23 chars, so base app name of 120 chars would exceed limit + string longAppName = new('A', 120); // 120 chars, plus 23 for hash = 143 total + + Mock httpContextAccessor = CreateHttpContextAccessorWithClaims( + issuer: "https://login.microsoftonline.com/tenant-id/v2.0", + objectId: "user-object-id-12345"); + + (MsSqlQueryExecutor queryExecutor, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: $"Server=localhost;Database=test;Application Name={longAppName};", + enableObo: true, + httpContextAccessor: httpContextAccessor); + + // Act + SqlConnection conn = queryExecutor.CreateConnection(provider.GetConfig().DefaultDataSourceName); + SqlConnectionStringBuilder connBuilder = new(conn.ConnectionString); + + // Assert - Application Name should be truncated to 128 chars max + Assert.IsTrue(connBuilder.ApplicationName.Length <= 128, + $"Application Name should be <= 128 chars but was {connBuilder.ApplicationName.Length} chars"); + + // Hash should still be at the start and complete (not truncated) + string[] parts = connBuilder.ApplicationName.Split('|'); + Assert.AreEqual(2, parts.Length, "Application Name should have exactly one '|' separator"); + + string hashPart = parts[0]; + Assert.IsTrue(hashPart.Length >= 20 && hashPart.Length <= 25, + $"Hash prefix should be ~22 chars (16 bytes Base64) but was {hashPart.Length} chars: '{hashPart}'"); + + // The base app name should be truncated, not the hash + string truncatedAppName = parts[1]; + Assert.IsTrue(truncatedAppName.Length < longAppName.Length, + $"Base app name should be truncated from {longAppName.Length} chars but was 
{truncatedAppName.Length} chars"); + Assert.IsTrue(truncatedAppName.All(c => c == 'A'), + "Truncated app name should contain only the original characters (no corruption)"); + } + + /// + /// Test that different users get different pool hashes (different Application Names). + /// + [TestMethod, TestCategory(TestCategory.MSSQL)] + public void TestObo_DifferentUsersGetDifferentPoolHashes() + { + // Arrange & Act - User 1 + Mock httpContextAccessor1 = CreateHttpContextAccessorWithClaims( + issuer: "https://login.microsoftonline.com/tenant-id/v2.0", + objectId: "user1-oid-aaaa"); + + (MsSqlQueryExecutor queryExecutor1, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: "Server=localhost;Database=test;Application Name=DAB;", + enableObo: true, + httpContextAccessor: httpContextAccessor1); + + SqlConnection conn1 = queryExecutor1.CreateConnection(provider.GetConfig().DefaultDataSourceName); + SqlConnectionStringBuilder connBuilder1 = new(conn1.ConnectionString); + + // Arrange & Act - User 2 + Mock httpContextAccessor2 = CreateHttpContextAccessorWithClaims( + issuer: "https://login.microsoftonline.com/tenant-id/v2.0", + objectId: "user2-oid-bbbb"); + + (MsSqlQueryExecutor queryExecutor2, RuntimeConfigProvider provider2) = CreateQueryExecutorForPoolingTest( + connectionString: "Server=localhost;Database=test;Application Name=DAB;", + enableObo: true, + httpContextAccessor: httpContextAccessor2); + + SqlConnection conn2 = queryExecutor2.CreateConnection(provider2.GetConfig().DefaultDataSourceName); + SqlConnectionStringBuilder connBuilder2 = new(conn2.ConnectionString); + + // Assert - both should have hash prefix and different hashes + // Format: {hash}|{appname} - hash is first to prevent truncation + Assert.IsTrue(connBuilder1.ApplicationName.Contains("|"), "User 1 should have hash prefix"); + Assert.IsTrue(connBuilder2.ApplicationName.Contains("|"), "User 2 should have hash prefix"); + 
Assert.AreNotEqual(connBuilder1.ApplicationName, connBuilder2.ApplicationName, + "Different users should have different Application Names (different pool hashes)"); + } + + /// + /// Test that when no user context is present (e.g., startup), connection string uses base Application Name. + /// + [TestMethod, TestCategory(TestCategory.MSSQL)] + public void TestOboNoUserContext_UsesBaseConnectionString() + { + // Arrange & Act + Mock httpContextAccessor = CreateHttpContextAccessorWithClaims(issuer: string.Empty, objectId: string.Empty); + + (MsSqlQueryExecutor queryExecutor, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: "Server=localhost;Database=test;Application Name=BaseApp;", + enableObo: true, + httpContextAccessor: httpContextAccessor); + + SqlConnection conn = queryExecutor.CreateConnection(provider.GetConfig().DefaultDataSourceName); + SqlConnectionStringBuilder connBuilder = new(conn.ConnectionString); + + // Assert - without user context, should use base Application Name (no hash prefix) + // Note: The actual format includes version suffix, e.g., "BaseApp,dab_oss_2.0.0" + Assert.IsTrue(connBuilder.ApplicationName.StartsWith("BaseApp"), + $"Without user context, Application Name should start with 'BaseApp' but was '{connBuilder.ApplicationName}'"); + // When no user context, the app name should NOT have the hash prefix pattern + // (hash prefix is 16 bytes Base64-encoded = ~22 chars, followed by |) + string[] parts = connBuilder.ApplicationName.Split('|'); + bool hasHashPrefix = parts.Length > 1 && parts[0].Length >= 20 && parts[0].Length <= 25; + Assert.IsFalse(hasHashPrefix, + $"Without user context, Application Name should not have hash prefix but was '{connBuilder.ApplicationName}'"); + } + + /// + /// Test that when OBO is enabled and a user is authenticated but missing required claims + /// (iss or oid/sub), CreateConnection throws DataApiBuilderException with OboAuthenticationFailure. 
+ /// This fail-safe behavior prevents cross-user connection pool contamination. + /// + [DataTestMethod, TestCategory(TestCategory.MSSQL)] + [DataRow("https://login.microsoftonline.com/tenant/v2.0", null, "oid/sub", + DisplayName = "Authenticated user with iss but missing oid/sub throws OboAuthenticationFailure")] + [DataRow(null, "user-object-id", "iss", + DisplayName = "Authenticated user with oid but missing iss throws OboAuthenticationFailure")] + [DataRow(null, null, "iss and oid/sub", + DisplayName = "Authenticated user with no claims throws OboAuthenticationFailure")] + public void TestOboEnabled_AuthenticatedUserMissingClaims_ThrowsException( + string? issuer, + string? objectId, + string missingClaimDescription) + { + // Arrange - Create an authenticated HttpContext with incomplete claims + Mock httpContextAccessor = CreateHttpContextAccessorWithAuthenticatedUserMissingClaims( + issuer: issuer, + objectId: objectId); + + (MsSqlQueryExecutor queryExecutor, RuntimeConfigProvider provider) = CreateQueryExecutorForPoolingTest( + connectionString: "Server=localhost;Database=test;Application Name=TestApp;", + enableObo: true, + httpContextAccessor: httpContextAccessor); + + // Act & Assert - CreateConnection should throw DataApiBuilderException + DataApiBuilderException exception = Assert.ThrowsException(() => + { + queryExecutor.CreateConnection(provider.GetConfig().DefaultDataSourceName); + }); + + Assert.AreEqual(HttpStatusCode.Unauthorized, exception.StatusCode, + $"Expected Unauthorized status code when missing {missingClaimDescription}"); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.OboAuthenticationFailure, exception.SubStatusCode, + $"Expected OboAuthenticationFailure sub-status code when missing {missingClaimDescription}"); + Assert.IsTrue(exception.Message.Contains("iss") && exception.Message.Contains("oid"), + $"Exception message should mention required claims. 
Actual: {exception.Message}"); + } + + /// + /// Creates an HttpContextAccessor mock with an authenticated user that has incomplete claims. + /// Used to test fail-safe behavior when OBO is enabled but required claims are missing. + /// + /// The issuer claim value, or null to omit. + /// The oid claim value, or null to omit. + /// A configured HttpContextAccessor mock with authenticated user. + private static Mock CreateHttpContextAccessorWithAuthenticatedUserMissingClaims( + string? issuer, + string? objectId) + { + Mock httpContextAccessor = new(); + DefaultHttpContext context = new(); + + // Create an authenticated identity (passing authenticationType makes IsAuthenticated = true) + System.Security.Claims.ClaimsIdentity identity = new("TestAuth"); + + // Only add claims if they are provided (non-null) + if (!string.IsNullOrEmpty(issuer)) + { + identity.AddClaim(new System.Security.Claims.Claim("iss", issuer)); + } + + if (!string.IsNullOrEmpty(objectId)) + { + identity.AddClaim(new System.Security.Claims.Claim("oid", objectId)); + } + + context.User = new System.Security.Claims.ClaimsPrincipal(identity); + httpContextAccessor.Setup(x => x.HttpContext).Returns(context); + + return httpContextAccessor; + } + + #endregion + [TestCleanup] public void CleanupAfterEachTest() { diff --git a/src/Service.Tests/UnitTests/StartupTests.cs b/src/Service.Tests/UnitTests/StartupTests.cs new file mode 100644 index 0000000000..1e90915cb7 --- /dev/null +++ b/src/Service.Tests/UnitTests/StartupTests.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using StackExchange.Redis; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass] + public class StartupTests + { + [DataTestMethod] + [DataRow("localhost:6379", false, DisplayName = "Localhost endpoint without password should NOT use Entra auth.")] + [DataRow("127.0.0.1:6379", false, DisplayName = "IPv4 loopback without password should NOT use Entra auth.")] + [DataRow("[::1]:6379", false, DisplayName = "IPv6 loopback without password should NOT use Entra auth.")] + [DataRow("redis.example.com:6380", true, DisplayName = "Remote endpoint without password SHOULD use Entra auth.")] + [DataRow("redis.example.com:6380,password=secret", false, DisplayName = "Presence of password should NOT use Entra auth, even for remote endpoints.")] + [DataRow("localhost:6379,redis.example.com:6380", true, DisplayName = "Mixed endpoints (including remote) without password SHOULD use Entra auth.")] + [DataRow("localhost:6379,password=secret", false, DisplayName = "Localhost with password should NOT use Entra auth.")] + public void ShouldUseEntraAuthForRedis(string connectionString, bool expectedUseEntraAuth) + { + // Arrange + var options = ConfigurationOptions.Parse(connectionString); + + // Act + bool result = Startup.ShouldUseEntraAuthForRedis(options); + + // Assert + Assert.AreEqual(expectedUseEntraAuth, result); + } + } +} diff --git a/src/Service.Tests/dab-config.DwSql.json b/src/Service.Tests/dab-config.DwSql.json index bb0d6a5893..a5eec531b8 100644 --- a/src/Service.Tests/dab-config.DwSql.json +++ b/src/Service.Tests/dab-config.DwSql.json @@ -2,7 +2,7 @@ "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json", "data-source": { "database-type": "dwsql", - "connection-string": "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=5;", + "connection-string": 
"Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=30;", "options": { "set-session-context": true } diff --git a/src/Service.Tests/dab-config.MsSql.json b/src/Service.Tests/dab-config.MsSql.json index e7f73702ba..157bf87b8c 100644 --- a/src/Service.Tests/dab-config.MsSql.json +++ b/src/Service.Tests/dab-config.MsSql.json @@ -2,7 +2,7 @@ "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json", "data-source": { "database-type": "mssql", - "connection-string": "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=5;", + "connection-string": "Server=tcp:127.0.0.1,1433;Persist Security Info=False;User ID=sa;Password=REPLACEME;MultipleActiveResultSets=False;Connection Timeout=30;", "options": { "set-session-context": true } diff --git a/src/Service/Azure.DataApiBuilder.Service.csproj b/src/Service/Azure.DataApiBuilder.Service.csproj index 5cf762ca57..d0478b83ed 100644 --- a/src/Service/Azure.DataApiBuilder.Service.csproj +++ b/src/Service/Azure.DataApiBuilder.Service.csproj @@ -64,6 +64,7 @@ + diff --git a/src/Service/Controllers/RestController.cs b/src/Service/Controllers/RestController.cs index cebd6f4463..07841fd8c3 100644 --- a/src/Service/Controllers/RestController.cs +++ b/src/Service/Controllers/RestController.cs @@ -222,6 +222,7 @@ private async Task HandleOperation( string routeAfterPathBase = _restService.GetRouteAfterPathBase(route); // Explicitly handle OpenAPI description document retrieval requests. + // Supports /openapi (superset of all roles) and /openapi/{role} (role-specific) if (string.Equals(routeAfterPathBase, OpenApiDocumentor.OPENAPI_ROUTE, StringComparison.OrdinalIgnoreCase)) { if (_openApiDocumentor.TryGetDocument(out string? 
document)) @@ -232,6 +233,33 @@ private async Task HandleOperation( return NotFound(); } + // Handle /openapi/{role} route for role-specific OpenAPI documents + // Only allow in Development mode for security reasons + if (routeAfterPathBase.StartsWith(OpenApiDocumentor.OPENAPI_ROUTE + "/", StringComparison.OrdinalIgnoreCase)) + { + RuntimeConfig config = _runtimeConfigProvider.GetConfig(); + if (config.Runtime?.Host?.Mode != HostMode.Development) + { + return NotFound(); + } + + string role = Uri.UnescapeDataString( + routeAfterPathBase.Substring(OpenApiDocumentor.OPENAPI_ROUTE.Length + 1)); + + // Validate role doesn't contain path separators (reject /openapi/foo/bar) + if (string.IsNullOrEmpty(role) || role.Contains('/')) + { + return NotFound(); + } + + if (_openApiDocumentor.TryGetDocumentForRole(role, out string? roleDocument)) + { + return Content(roleDocument, MediaTypeNames.Application.Json); + } + + return NotFound(); + } + (string entityName, string primaryKeyRoute) = _restService.GetEntityNameAndPrimaryKeyRouteFromRoute(routeAfterPathBase); // This activity tracks the query execution. This will create a new activity nested under the REST request activity. diff --git a/src/Service/HealthCheck/ComprehensiveHealthReportResponseWriter.cs b/src/Service/HealthCheck/ComprehensiveHealthReportResponseWriter.cs index 2555890791..5027c5e059 100644 --- a/src/Service/HealthCheck/ComprehensiveHealthReportResponseWriter.cs +++ b/src/Service/HealthCheck/ComprehensiveHealthReportResponseWriter.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System; -using System.IO; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -76,8 +75,8 @@ public async Task WriteResponseAsync(HttpContext context) // Global comprehensive Health Check Enabled if (config.IsHealthEnabled) { - _healthCheckHelper.StoreIncomingRoleHeader(context); - if (!_healthCheckHelper.IsUserAllowedToAccessHealthCheck(context, config.IsDevelopmentMode(), config.AllowedRolesForHealth)) + (string roleHeader, string roleToken) = _healthCheckHelper.ReadRoleHeaders(context); + if (!_healthCheckHelper.IsUserAllowedToAccessHealthCheck(config.IsDevelopmentMode(), config.AllowedRolesForHealth, roleHeader)) { _logger.LogError("Comprehensive Health Check Report is not allowed: 403 Forbidden due to insufficient permissions."); context.Response.StatusCode = StatusCodes.Status403Forbidden; @@ -85,34 +84,33 @@ public async Task WriteResponseAsync(HttpContext context) return; } - string? response; // Check if the cache is enabled if (config.CacheTtlSecondsForHealthReport > 0) { + ComprehensiveHealthCheckReport? report = null; try { - response = await _cache.GetOrSetAsync( + report = await _cache.GetOrSetAsync( key: CACHE_KEY, - async (FusionCacheFactoryExecutionContext ctx, CancellationToken ct) => + async (FusionCacheFactoryExecutionContext ctx, CancellationToken ct) => { - string? response = await ExecuteHealthCheckAsync(config).ConfigureAwait(false); + ComprehensiveHealthCheckReport? 
r = await _healthCheckHelper.GetHealthCheckResponseAsync(config, roleHeader, roleToken).ConfigureAwait(false); ctx.Options.SetDuration(TimeSpan.FromSeconds(config.CacheTtlSecondsForHealthReport)); - return response; + return r; }); _logger.LogTrace($"Health check response is fetched from cache with key: {CACHE_KEY} and TTL: {config.CacheTtlSecondsForHealthReport} seconds."); } catch (Exception ex) { - response = null; // Set response to null in case of an error _logger.LogError($"Error in caching health check response: {ex.Message}"); } // Ensure cachedResponse is not null before calling WriteAsync - if (response != null) + if (report != null) { - // Return the cached or newly generated response - await context.Response.WriteAsync(response); + // Set currentRole per-request (not cached) so each caller sees their own role + await context.Response.WriteAsync(SerializeReport(report with { CurrentRole = _healthCheckHelper.GetCurrentRole(roleHeader, roleToken) })); } else { @@ -124,9 +122,9 @@ public async Task WriteResponseAsync(HttpContext context) } else { - response = await ExecuteHealthCheckAsync(config).ConfigureAwait(false); + ComprehensiveHealthCheckReport report = await _healthCheckHelper.GetHealthCheckResponseAsync(config, roleHeader, roleToken).ConfigureAwait(false); // Return the newly generated response - await context.Response.WriteAsync(response); + await context.Response.WriteAsync(SerializeReport(report with { CurrentRole = _healthCheckHelper.GetCurrentRole(roleHeader, roleToken) })); } } else @@ -139,13 +137,10 @@ public async Task WriteResponseAsync(HttpContext context) return; } - private async Task ExecuteHealthCheckAsync(RuntimeConfig config) + private string SerializeReport(ComprehensiveHealthCheckReport report) { - ComprehensiveHealthCheckReport dabHealthCheckReport = await _healthCheckHelper.GetHealthCheckResponseAsync(config); - string response = JsonSerializer.Serialize(dabHealthCheckReport, options: new JsonSerializerOptions { WriteIndented = 
true, DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull }); - _logger.LogTrace($"Health check response writer writing status as: {dabHealthCheckReport.Status}"); - - return response; + _logger.LogTrace($"Health check response writer writing status as: {report.Status}"); + return JsonSerializer.Serialize(report, options: new JsonSerializerOptions { WriteIndented = true, DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull }); } } } diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index ab19756195..2a5f6f5ddf 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -27,8 +27,6 @@ public class HealthCheckHelper // Dependencies private ILogger<HealthCheckHelper> _logger; private HttpUtilities _httpUtility; - private string _incomingRoleHeader = string.Empty; - private string _incomingRoleToken = string.Empty; private const string TIME_EXCEEDED_ERROR_MESSAGE = "The threshold for executing the request has exceeded."; @@ -48,8 +46,10 @@ public HealthCheckHelper(ILogger logger, HttpUtilities httpUt /// Serializes the report to JSON and returns the response. /// /// RuntimeConfig + /// The effective role header for the current request. + /// The bearer token for the current request. /// This function returns the comprehensive health report after calculating the response time of each datasource, rest and graphql health queries. - public async Task<ComprehensiveHealthCheckReport> GetHealthCheckResponseAsync(RuntimeConfig runtimeConfig) + public async Task<ComprehensiveHealthCheckReport> GetHealthCheckResponseAsync(RuntimeConfig runtimeConfig, string roleHeader, string roleToken) { // Create a JSON response for the comprehensive health check endpoint using the provided basic health report. // If the response has already been created, it will be reused.
@@ -59,13 +59,13 @@ public async Task GetHealthCheckResponseAsync(Ru UpdateVersionAndAppName(ref comprehensiveHealthCheckReport); UpdateTimestampOfResponse(ref comprehensiveHealthCheckReport); UpdateDabConfigurationDetails(ref comprehensiveHealthCheckReport, runtimeConfig); - await UpdateHealthCheckDetailsAsync(comprehensiveHealthCheckReport, runtimeConfig); + await UpdateHealthCheckDetailsAsync(comprehensiveHealthCheckReport, runtimeConfig, roleHeader, roleToken); UpdateOverallHealthStatus(ref comprehensiveHealthCheckReport); return comprehensiveHealthCheckReport; } - // Updates the incoming role header with the appropriate value from the request headers. - public void StoreIncomingRoleHeader(HttpContext httpContext) + // Reads the incoming role and token headers from the request and returns them as local values. + public (string roleHeader, string roleToken) ReadRoleHeaders(HttpContext httpContext) { StringValues clientRoleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; StringValues clientTokenHeader = httpContext.Request.Headers[AuthenticationOptions.CLIENT_PRINCIPAL_HEADER]; @@ -75,27 +75,31 @@ public void StoreIncomingRoleHeader(HttpContext httpContext) throw new ArgumentException("Multiple values for the client role or token header are not allowed."); } - // Role Header is not present in the request, set it to anonymous. - if (clientRoleHeader.Count == 1) - { - _incomingRoleHeader = clientRoleHeader.ToString().ToLowerInvariant(); - } + string roleHeader = clientRoleHeader.Count == 1 ? clientRoleHeader.ToString().ToLowerInvariant() : string.Empty; + string roleToken = clientTokenHeader.Count == 1 ? clientTokenHeader.ToString() : string.Empty; + return (roleHeader, roleToken); + } - if (clientTokenHeader.Count == 1) - { - _incomingRoleToken = clientTokenHeader.ToString(); - } + // Returns the effective role for the current request. + // Falls back to "authenticated" if a bearer token is present, or "anonymous" otherwise. 
+ public string GetCurrentRole(string roleHeader, string roleToken) + { + return !string.IsNullOrEmpty(roleHeader) + ? roleHeader + : !string.IsNullOrEmpty(roleToken) + ? AuthorizationResolver.ROLE_AUTHENTICATED + : AuthorizationResolver.ROLE_ANONYMOUS; + } /// /// Checks if the incoming request is allowed to access the health check endpoint. /// Anonymous requests are only allowed in Development Mode. /// - /// HttpContext to get the headers. - /// Compare with the HostMode of DAB + /// Compare with the HostMode of DAB /// AllowedRoles in the Runtime.Health config + /// The effective role header for the current request. /// - public bool IsUserAllowedToAccessHealthCheck(HttpContext httpContext, bool isDevelopmentMode, HashSet<string> allowedRoles) + public bool IsUserAllowedToAccessHealthCheck(bool isDevelopmentMode, HashSet<string> allowedRoles, string roleHeader) { if (allowedRoles == null || allowedRoles.Count == 0) { @@ -103,7 +107,7 @@ public bool IsUserAllowedToAccessHealthCheck(HttpContext httpContext, bool isDev return isDevelopmentMode; } - return allowedRoles.Contains(_incomingRoleHeader); + return allowedRoles.Contains(roleHeader); } // Updates the overall status by comparing all the internal HealthStatuses in the response. @@ -149,11 +153,11 @@ private static void UpdateDabConfigurationDetails(ref ComprehensiveHealthCheckRe } // Main function to internally call for data source and entities health check.
- private async Task UpdateHealthCheckDetailsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) + private async Task UpdateHealthCheckDetailsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig, string roleHeader, string roleToken) { comprehensiveHealthCheckReport.Checks = new List(); await UpdateDataSourceHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); - await UpdateEntityHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); + await UpdateEntityHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig, roleHeader, roleToken); } // Updates the DataSource Health Check Results in the response. @@ -162,7 +166,7 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh if (comprehensiveHealthCheckReport.Checks != null && runtimeConfig.DataSource.IsDatasourceHealthEnabled) { string query = Utilities.GetDatSourceQuery(runtimeConfig.DataSource.DatabaseType); - (int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType)); + (int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType), runtimeConfig.DataSource.DatabaseType); bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < runtimeConfig.DataSource.DatasourceThresholdMs; // Add DataSource Health Check Results @@ -182,14 +186,14 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh } // Executes the DB Query and keeps track of the response time and error message. 
- private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory) + private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory, DatabaseType databaseType) { string? errorMessage = null; if (!string.IsNullOrEmpty(query) && !string.IsNullOrEmpty(connectionString)) { Stopwatch stopwatch = new(); stopwatch.Start(); - errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory); + errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory, databaseType); stopwatch.Stop(); return string.IsNullOrEmpty(errorMessage) ? ((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } @@ -200,7 +204,7 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh // Updates the Entity Health Check Results in the response. // Goes through the entities one by one and executes the rest and graphql checks (if enabled). // Stored procedures are excluded from health checks because they require parameters and are not guaranteed to be deterministic. 
- private async Task UpdateEntityHealthCheckResultsAsync(ComprehensiveHealthCheckReport report, RuntimeConfig runtimeConfig) + private async Task UpdateEntityHealthCheckResultsAsync(ComprehensiveHealthCheckReport report, RuntimeConfig runtimeConfig, string roleHeader, string roleToken) { List> enabledEntities = runtimeConfig.Entities.Entities .Where(e => e.Value.IsEntityHealthEnabled && e.Value.Source.Type != EntitySourceType.StoredProcedure) @@ -232,7 +236,7 @@ private async Task UpdateEntityHealthCheckResultsAsync(ComprehensiveHealthCheckR Checks = new List() }; - await PopulateEntityHealthAsync(localReport, entity, runtimeConfig); + await PopulateEntityHealthAsync(localReport, entity, runtimeConfig, roleHeader, roleToken); if (localReport.Checks != null) { @@ -255,7 +259,7 @@ private async Task UpdateEntityHealthCheckResultsAsync(ComprehensiveHealthCheckR // Populates the Entity Health Check Results in the response for a particular entity. // Checks for Rest enabled and executes the rest query. // Checks for GraphQL enabled and executes the graphql query. - private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, KeyValuePair entity, RuntimeConfig runtimeConfig) + private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, KeyValuePair entity, RuntimeConfig runtimeConfig, string roleHeader, string roleToken) { // Global Rest and GraphQL Runtime Options RuntimeOptions? runtimeOptions = runtimeConfig.Runtime; @@ -274,7 +278,7 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp // The path is trimmed to remove the leading '/' character. // If the path is not present, use the entity key name as the path. string entityPath = entityValue.Rest.Path != null ? entityValue.Rest.Path.TrimStart('/') : entityKeyName; - (int, string?) 
response = await ExecuteRestEntityQueryAsync(runtimeConfig.RestPath, entityPath, entityValue.EntityFirst); + (int, string?) response = await ExecuteRestEntityQueryAsync(runtimeConfig.RestPath, entityPath, entityValue.EntityFirst, roleHeader, roleToken); bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < entityValue.EntityThresholdMs; // Add Entity Health Check Results @@ -296,7 +300,7 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp { comprehensiveHealthCheckReport.Checks ??= new List(); - (int, string?) response = await ExecuteGraphQlEntityQueryAsync(runtimeConfig.GraphQLPath, entityValue, entityKeyName); + (int, string?) response = await ExecuteGraphQlEntityQueryAsync(runtimeConfig.GraphQLPath, entityValue, entityKeyName, roleHeader, roleToken); bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < entityValue.EntityThresholdMs; comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry @@ -316,14 +320,14 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp } // Executes the Rest Entity Query and keeps track of the response time and error message. - private async Task<(int, string?)> ExecuteRestEntityQueryAsync(string restUriSuffix, string entityName, int first) + private async Task<(int, string?)> ExecuteRestEntityQueryAsync(string restUriSuffix, string entityName, int first, string roleHeader, string roleToken) { string? errorMessage = null; if (!string.IsNullOrEmpty(entityName)) { Stopwatch stopwatch = new(); stopwatch.Start(); - errorMessage = await _httpUtility.ExecuteRestQueryAsync(restUriSuffix, entityName, first, _incomingRoleHeader, _incomingRoleToken); + errorMessage = await _httpUtility.ExecuteRestQueryAsync(restUriSuffix, entityName, first, roleHeader, roleToken); stopwatch.Stop(); return string.IsNullOrEmpty(errorMessage) ? 
((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } @@ -332,14 +336,14 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp } // Executes the GraphQL Entity Query and keeps track of the response time and error message. - private async Task<(int, string?)> ExecuteGraphQlEntityQueryAsync(string graphqlUriSuffix, Entity entity, string entityName) + private async Task<(int, string?)> ExecuteGraphQlEntityQueryAsync(string graphqlUriSuffix, Entity entity, string entityName, string roleHeader, string roleToken) { string? errorMessage = null; if (entity != null) { Stopwatch stopwatch = new(); stopwatch.Start(); - errorMessage = await _httpUtility.ExecuteGraphQLQueryAsync(graphqlUriSuffix, entityName, entity, _incomingRoleHeader, _incomingRoleToken); + errorMessage = await _httpUtility.ExecuteGraphQLQueryAsync(graphqlUriSuffix, entityName, entity, roleHeader, roleToken); stopwatch.Stop(); return string.IsNullOrEmpty(errorMessage) ? ((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } diff --git a/src/Service/HealthCheck/HttpUtilities.cs b/src/Service/HealthCheck/HttpUtilities.cs index 9da596ae30..2a8d7b9f3e 100644 --- a/src/Service/HealthCheck/HttpUtilities.cs +++ b/src/Service/HealthCheck/HttpUtilities.cs @@ -49,7 +49,7 @@ public HttpUtilities( } // Executes the DB query by establishing a connection to the DB. - public async Task ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory) + public async Task ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory, DatabaseType databaseType) { string? errorMessage = null; // Execute the query on DB and return the response time. 
@@ -65,7 +65,7 @@ public HttpUtilities( { try { - connection.ConnectionString = connectionString; + connection.ConnectionString = Utilities.NormalizeConnectionString(connectionString, databaseType, _logger); using (DbCommand command = connection.CreateCommand()) { command.CommandText = query; diff --git a/src/Service/HealthCheck/Model/ComprehensiveHealthCheckReport.cs b/src/Service/HealthCheck/Model/ComprehensiveHealthCheckReport.cs index b649a6bfc7..26a260af47 100644 --- a/src/Service/HealthCheck/Model/ComprehensiveHealthCheckReport.cs +++ b/src/Service/HealthCheck/Model/ComprehensiveHealthCheckReport.cs @@ -43,6 +43,12 @@ public record ComprehensiveHealthCheckReport [JsonPropertyName("timestamp")] public DateTime TimeStamp { get; set; } + /// + /// The current role of the user making the request (e.g., "anonymous", "authenticated"). + /// + [JsonPropertyName("currentRole")] + public string? CurrentRole { get; set; } + /// /// The configuration details of the dab service. /// diff --git a/src/Service/HealthCheck/Utilities.cs b/src/Service/HealthCheck/Utilities.cs index 290410291e..888ffbca91 100644 --- a/src/Service/HealthCheck/Utilities.cs +++ b/src/Service/HealthCheck/Utilities.cs @@ -7,6 +7,8 @@ using System.Text.Json; using Azure.DataApiBuilder.Config.ObjectModel; using Microsoft.Data.SqlClient; +using Microsoft.Extensions.Logging; +using MySqlConnector; using Npgsql; namespace Azure.DataApiBuilder.Service.HealthCheck @@ -69,5 +71,32 @@ public static string CreateHttpRestQuery(string entityName, int first) // "EntityName?$first=4" return $"/{entityName}?$first={first}"; } + + public static string NormalizeConnectionString(string connectionString, DatabaseType dbType, ILogger? 
logger = null) + { + try + { + switch (dbType) + { + case DatabaseType.PostgreSQL: + return new NpgsqlConnectionStringBuilder(connectionString).ToString(); + case DatabaseType.MySQL: + return new MySqlConnectionStringBuilder(connectionString).ToString(); + case DatabaseType.MSSQL: + case DatabaseType.DWSQL: + return new SqlConnectionStringBuilder(connectionString).ToString(); + default: + return connectionString; + } + } + catch (Exception ex) + { + // Log the exception if a logger is provided + logger?.LogWarning(ex, "Failed to parse connection string for database type {DatabaseType}. Returning original connection string.", dbType); + // If the connection string cannot be parsed by the builder, + // return the original string to avoid failing the health check. + return connectionString; + } + } } } diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 333bf57234..c24d8be8a8 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -3,6 +3,8 @@ using System; using System.IO.Abstractions; +using System.Linq; +using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Threading.Tasks; @@ -53,6 +55,7 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Identity.Client; using NodaTime; using OpenTelemetry.Exporter; using OpenTelemetry.Logs; @@ -66,6 +69,7 @@ using ZiggyCreatures.Caching.Fusion.Backplane.StackExchangeRedis; using ZiggyCreatures.Caching.Fusion.Serialization.SystemTextJson; using CorsOptions = Azure.DataApiBuilder.Config.ObjectModel.CorsOptions; +using LogLevel = Microsoft.Extensions.Logging.LogLevel; namespace Azure.DataApiBuilder.Service { @@ -250,6 +254,38 @@ public void ConfigureServices(IServiceCollection services) // within these factories the various instances will be created based on the database type and datasourceName. services.AddSingleton(); + // Register IOboTokenProvider only when user-delegated auth is configured. 
+ // This avoids registering a null singleton and supports hot-reload scenarios. + // Requires environment variables: DAB_OBO_CLIENT_ID, DAB_OBO_TENANT_ID, DAB_OBO_CLIENT_SECRET + // + // Design note: A single IOboTokenProvider is registered using one Azure AD app registration. + // Multiple databases with different database-audience values ARE supported - the audience + // is passed to GetAccessTokenOnBehalfOfAsync() at query execution time, allowing the same + // MSAL client to acquire tokens for different resource servers. + if (IsOboConfigured()) + { + // Register IMsalClientWrapper for dependency injection + services.AddSingleton(serviceProvider => + { + string? clientId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR); + string? tenantId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR); + string? clientSecret = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR); + + string authority = $"https://login.microsoftonline.com/{tenantId}"; + + IConfidentialClientApplication msalClient = ConfidentialClientApplicationBuilder + .Create(clientId) + .WithAuthority(authority) + .WithClientSecret(clientSecret) + .Build(); + + return new MsalClientWrapper(msalClient); + }); + + // Register OboSqlTokenProvider with dependencies from DI + services.AddSingleton(); + } + services.AddSingleton(); services.AddSingleton(); @@ -435,7 +471,7 @@ public void ConfigureServices(IServiceCollection services) else { // NOTE: this is done to reuse the same connection multiplexer for both the cache and backplane - Task connectionMultiplexerTask = ConnectionMultiplexer.ConnectAsync(level2CacheOptions.ConnectionString); + Task connectionMultiplexerTask = CreateConnectionMultiplexerAsync(level2CacheOptions.ConnectionString); fusionCacheBuilder .WithSerializer(new FusionCacheSystemTextJsonSerializer()) @@ -467,9 +503,94 @@ public void ConfigureServices(IServiceCollection 
services) services.AddSingleton(); + // Add Response Compression services based on config + ConfigureResponseCompression(services, runtimeConfig); + services.AddControllers(); } + /// + /// Creates a ConnectionMultiplexer for Redis with support for Azure Entra authentication. + /// + /// The Redis connection string. + /// A task that represents the asynchronous operation. The task result contains the connected IConnectionMultiplexer. + private static async Task<IConnectionMultiplexer> CreateConnectionMultiplexerAsync(string connectionString) + { + ConfigurationOptions options = ConfigurationOptions.Parse(connectionString); + + if (ShouldUseEntraAuthForRedis(options)) + { + options = await options.ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential()); + } + + return await ConnectionMultiplexer.ConnectAsync(options); + } + + /// + /// Determines whether Azure Entra authentication should be used. + /// Conditions: + /// - No password provided + /// - At least one endpoint is NOT localhost/loopback + /// + /// The Redis configuration options. + /// True if Azure Entra authentication should be used; otherwise, false. + /// Internal for testing. + internal static bool ShouldUseEntraAuthForRedis(ConfigurationOptions options) + { + // Determine if an endpoint is localhost/loopback + static bool IsLocalhostEndpoint(EndPoint ep) => ep switch + { + DnsEndPoint dns => string.Equals(dns.Host, "localhost", StringComparison.OrdinalIgnoreCase), + IPEndPoint ip => IPAddress.IsLoopback(ip.Address), + _ => false, + }; + + return string.IsNullOrEmpty(options.Password) + && options.EndPoints.Any(ep => !IsLocalhostEndpoint(ep)); + } + + /// + /// Configures HTTP response compression based on the runtime configuration. + /// Compression is applied at the middleware level and supports Gzip and Brotli. + /// Applies to REST, GraphQL, and MCP endpoints. + /// + private void ConfigureResponseCompression(IServiceCollection services, RuntimeConfig?
runtimeConfig) + { + CompressionLevel compressionLevel = runtimeConfig?.Runtime?.Compression?.Level ?? CompressionOptions.DEFAULT_LEVEL; + + // Only configure compression if level is not None + if (compressionLevel == CompressionLevel.None) + { + return; + } + + System.IO.Compression.CompressionLevel systemCompressionLevel = compressionLevel switch + { + CompressionLevel.Fastest => System.IO.Compression.CompressionLevel.Fastest, + CompressionLevel.Optimal => System.IO.Compression.CompressionLevel.Optimal, + _ => System.IO.Compression.CompressionLevel.Optimal + }; + + services.AddResponseCompression(options => + { + options.EnableForHttps = true; + options.Providers.Add<GzipCompressionProvider>(); + options.Providers.Add<BrotliCompressionProvider>(); + }); + + services.Configure<GzipCompressionProviderOptions>(options => + { + options.Level = systemCompressionLevel; + }); + + services.Configure<BrotliCompressionProviderOptions>(options => + { + options.Level = systemCompressionLevel; + }); + + _logger.LogInformation("Response compression enabled with level '{compressionLevel}' for REST, GraphQL, and MCP endpoints.", compressionLevel); + } + /// /// Configure GraphQL services within the service collection of the /// request pipeline. @@ -615,6 +736,13 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, RuntimeC ); } + // Response compression middleware should be placed early in the pipeline. + // Only use if compression is not set to None. + if (runtimeConfig?.Runtime?.Compression?.Level is not CompressionLevel.None) + { + app.UseResponseCompression(); + } + // URL Rewrite middleware MUST be called prior to UseRouting(). // https://andrewlock.net/understanding-pathbase-in-aspnetcore/#placing-usepathbase-in-the-correct-location app.UseCorrelationIdMiddleware(); @@ -1048,12 +1176,6 @@ private async Task PerformOnConfigChangeAsync(IApplicationBuilder app) runtimeConfigValidator.ValidateConfigProperties(); - if (runtimeConfig.IsDevelopmentMode()) - { - // Running only in developer mode to ensure fast and smooth startup in production.
- runtimeConfigValidator.ValidatePermissionsInConfig(runtimeConfig); - } - IMetadataProviderFactory sqlMetadataProviderFactory = app.ApplicationServices.GetRequiredService<IMetadataProviderFactory>(); await sqlMetadataProviderFactory.InitializeAsync(); @@ -1147,6 +1269,42 @@ private static bool IsUIEnabled(RuntimeConfig? runtimeConfig, IWebHostEnvironmen return (runtimeConfig is not null && runtimeConfig.IsDevelopmentMode()) || env.IsDevelopment(); } + /// + /// Checks whether On-Behalf-Of (OBO) authentication is configured by verifying that + /// the required environment variables are set and the config has user-delegated auth enabled. + /// + /// True if OBO is configured and ready to use; otherwise, false. + private bool IsOboConfigured() + { + // Check required environment variables first (fast path) + string? clientId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_ID_ENV_VAR); + string? tenantId = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_TENANT_ID_ENV_VAR); + string? clientSecret = Environment.GetEnvironmentVariable(UserDelegatedAuthOptions.DAB_OBO_CLIENT_SECRET_ENV_VAR); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(tenantId) || string.IsNullOrEmpty(clientSecret)) + { + return false; + } + + // Check if any data source has user-delegated auth enabled + RuntimeConfig? config = _configProvider?.TryGetConfig(out RuntimeConfig? c) == true ? c : null; + if (config is null) + { + return false; + } + + foreach (DataSource ds in config.ListAllDataSources()) + { + if (ds.IsUserDelegatedAuthEnabled && + !string.IsNullOrEmpty(ds.UserDelegatedAuth!.DatabaseAudience)) + { + return true; + } + } + + return false; + } + /// /// Adds all of the class namespaces that have loggers that the user is able to change ///