diff --git a/README.md b/README.md index 13f971f2..b1648aae 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,16 @@ DDL to be executed: ALTER TABLE users ADD COLUMN age integer NOT NULL; ``` +#### Multi-schema `plan` + +To compare **more than one** PostgreSQL namespace in one run, pass a comma-separated list as the flag value. The first name in the list is the *primary* schema (where unqualified DDL in `--file` is rooted); remaining names are also loaded from the target and from the plan database after applying your SQL. Example: + +```bash +pgschema plan --schema public,app --file schema.sql ... +``` + +See [cmd/plan/README.md](cmd/plan/README.md) for full behaviour and caveats (including how this relates to `dump` / `apply`). + ### Step 4: Apply plan with confirmation ```bash diff --git a/cmd/plan/README.md b/cmd/plan/README.md index ec08af1f..719cb390 100644 --- a/cmd/plan/README.md +++ b/cmd/plan/README.md @@ -1,9 +1,66 @@ -# Plan Command +# Plan command -## Running Tests +Compare a desired-state SQL file with the live database and print migration DDL (human, JSON, and/or SQL). + +## Usage + +```bash +pgschema plan \ + --host localhost \ + --port 5432 \ + --db mydb \ + --user postgres \ + --schema public \ + --file schema.sql \ + --output-human stdout +``` + +- **`--file`**: path to the SQL that describes the desired schema (required). +- **`--schema`**: see [Single schema](#single-schema) and [Multi-schema](#multi-schema) below. +- **Outputs**: `--output-human`, `--output-json`, `--output-sql` (each can be `stdout` or a file path). If none are set, human output goes to stdout. + +Optional **plan database** (instead of the default embedded Postgres): `--plan-host`, `--plan-port`, `--plan-db`, `--plan-user`, `--plan-password`, `--plan-sslmode`, or `PGSCHEMA_PLAN_*` env vars. See the main project docs for details. + +## Single schema + +`--schema` defaults to `public`. 
Only that PostgreSQL namespace is loaded from the target database and from the temporary plan database after your SQL is applied. + +## Multi-schema + +Pass a **comma-separated** list of schema names (spaces trimmed, duplicates removed): + +```bash +pgschema plan \ + --schema public,app \ + --file schema.sql \ + ... +``` + +### Behaviour + +1. **Target (current) state** + All listed schemas are introspected and merged into one IR, so the diff can see tables, views, functions, etc. in `public`, `app`, and any other name you include. + +2. **Desired state** + - The **first** name in the list is the *primary* schema: your `--file` SQL is applied in the temporary plan database with that schema as the strip/normalize target (same as single-schema plan). + - After that, the temporary schema **and** every other listed schema are introspected on the plan database. That way, objects you created with explicit qualification (e.g. `app.some_table`) appear in the desired IR as long as `app` is included after the comma. + +3. **Generated DDL** + Diffing still uses the primary schema for name normalization where applicable; cross-schema references in the IR are preserved as in single-schema mode. + +### When to use it + +Use multi-schema when a single migration touches more than one namespace (e.g. `public` facts and `app` dimensions) or when foreign keys span schemas you want in the same plan. + +### Caveats + +- **`dump` / `apply`**: today their `--schema` flag is still a **single** schema name for connection defaults and fingerprinting. A plan built with `--schema public,app` can include DDL for multiple namespaces; applying it may require running `apply` with a workflow that matches your process (for example separate apply runs per schema if you rely on `search_path`), or extending apply in the future. +- **Order matters**: always put the schema where the bulk of unqualified DDL in `--file` lives **first**. 
+ +## Running tests ```bash -# All plan tests +# All plan tests go test -v ./cmd/plan/ # Specific plan tests diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index afd5e3bd..14710dcd 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -45,7 +45,7 @@ var ( var PlanCmd = &cobra.Command{ Use: "plan", Short: "Generate migration plan for a specific schema", - Long: "Generate a migration plan to apply a desired schema state to a target database schema. Compares the desired state (from --file) with the current state of a specific schema (specified by --schema, defaults to 'public').", + Long: "Generate a migration plan to apply a desired schema state to a target database. Compares the desired state (from --file) with the current state. Use --schema with a comma-separated list (e.g. public,app) to compare and inspect multiple PostgreSQL namespaces.", RunE: runPlan, SilenceUsage: true, PreRunE: util.PreRunEWithEnvVarsAndConnection(&planDB, &planUser, &planHost, &planPort), @@ -58,7 +58,7 @@ func init() { PlanCmd.Flags().StringVar(&planDB, "db", "", "Database name (required) (env: PGDATABASE)") PlanCmd.Flags().StringVar(&planUser, "user", "", "Database user name (required) (env: PGUSER)") PlanCmd.Flags().StringVar(&planPassword, "password", "", "Database password (optional, can also use PGPASSWORD env var)") - PlanCmd.Flags().StringVar(&planSchema, "schema", "public", "Schema name") + PlanCmd.Flags().StringVar(&planSchema, "schema", "public", "Schema name, or comma-separated list for multi-namespace plan (e.g. public,app)") // Desired state schema file flag PlanCmd.Flags().StringVar(&planFile, "file", "", "Path to desired state SQL schema file (required)") @@ -267,6 +267,9 @@ func CreateEmbeddedPostgresForPlan(config *PlanConfig, pgVersion postgres.Postgr // The caller must provide a non-nil provider instance for validating the desired state schema. // The caller is responsible for managing the provider lifecycle (creation and cleanup). 
func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (*plan.Plan, error) { + planSchemas := util.ParseSchemaList(config.Schema) + primarySchema := planSchemas[0] + // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { @@ -287,7 +290,7 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* } // Compute fingerprint of current database state - sourceFingerprint, err := fingerprint.ComputeFingerprint(currentStateIR, config.Schema) + sourceFingerprint, err := fingerprint.ComputeFingerprintForSchemas(currentStateIR, planSchemas) if err != nil { return nil, fmt.Errorf("failed to compute source fingerprint: %w", err) } @@ -295,7 +298,7 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* ctx := context.Background() // Apply desired state SQL to the provider (embedded postgres or external database) - if err := provider.ApplySchema(ctx, config.Schema, desiredState); err != nil { + if err := provider.ApplySchema(ctx, primarySchema, desiredState); err != nil { return nil, fmt.Errorf("failed to apply desired state: %w", err) } @@ -307,8 +310,14 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* // (e.g., pgschema_tmp_20251030_154501_123456789) to ensure isolation and prevent conflicts. schemaToInspect := provider.GetSchemaName() if schemaToInspect == "" { - schemaToInspect = config.Schema + schemaToInspect = primarySchema + } + + inspectSchemas := []string{schemaToInspect} + for _, s := range planSchemas[1:] { + inspectSchemas = append(inspectSchemas, s) } + inspectSpec := strings.Join(inspectSchemas, ",") // For embedded postgres, always use "disable" since it starts without SSL configured. // For external plan databases, use the configured PlanDBSSLMode (defaulting to "prefer"). 
@@ -319,7 +328,7 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* providerSSLMode = "prefer" } } - desiredStateIR, err := util.GetIRFromDatabase(providerHost, providerPort, providerDB, providerUsername, providerPassword, providerSSLMode, schemaToInspect, config.ApplicationName, ignoreConfig) + desiredStateIR, err := util.GetIRFromDatabase(providerHost, providerPort, providerDB, providerUsername, providerPassword, providerSSLMode, inspectSpec, config.ApplicationName, ignoreConfig) if err != nil { return nil, fmt.Errorf("failed to get desired state: %w", err) } @@ -329,12 +338,12 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* // because that's where objects were created. We need to replace these with the target // schema name (e.g., "public") so that generated DDL references the correct schema. // Without this normalization, DDL would reference non-existent temporary schemas and fail. - if schemaToInspect != config.Schema { - normalizeSchemaNames(desiredStateIR, schemaToInspect, config.Schema) + if schemaToInspect != primarySchema { + normalizeSchemaNames(desiredStateIR, schemaToInspect, primarySchema) } // Generate diff (current -> desired) using IR directly - diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, config.Schema) + diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, primarySchema) // Create plan from diffs with fingerprint migrationPlan := plan.NewPlanWithFingerprint(diffs, sourceFingerprint) diff --git a/cmd/util/connection.go b/cmd/util/connection.go index 9eba0422..d883139d 100644 --- a/cmd/util/connection.go +++ b/cmd/util/connection.go @@ -94,7 +94,31 @@ func ValidateSSLMode(mode string) error { } } -// GetIRFromDatabase gets the IR from a database with ignore configuration +// ParseSchemaList splits a comma-separated schema specification into non-empty names. +// Empty input yields ["public"]. Duplicates are removed while preserving order. 
+func ParseSchemaList(spec string) []string { + parts := strings.Split(spec, ",") + out := make([]string, 0, len(parts)) + seen := make(map[string]struct{}) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := seen[p]; ok { + continue + } + seen[p] = struct{}{} + out = append(out, p) + } + if len(out) == 0 { + return []string{"public"} + } + return out +} + +// GetIRFromDatabase gets the IR from a database with ignore configuration. +// schemaName may be a comma-separated list (e.g. "public,app") to merge multiple namespaces. func GetIRFromDatabase(host string, port int, db, user, password, sslmode, schemaName, applicationName string, ignoreConfig *ir.IgnoreConfig) (*ir.IR, error) { if sslmode == "" { sslmode = "prefer" @@ -122,13 +146,9 @@ func GetIRFromDatabase(host string, port int, db, user, password, sslmode, schem // Build IR using the IR system with ignore config inspector := ir.NewInspector(conn, ignoreConfig) - // Default to public schema if none specified - targetSchema := schemaName - if targetSchema == "" { - targetSchema = "public" - } + schemas := ParseSchemaList(schemaName) - schemaIR, err := inspector.BuildIR(ctx, targetSchema) + schemaIR, err := inspector.BuildIRFromSchemas(ctx, schemas) if err != nil { return nil, fmt.Errorf("failed to build IR: %w", err) } diff --git a/docs/cli/plan.mdx b/docs/cli/plan.mdx index 1c331519..fd98bffc 100644 --- a/docs/cli/plan.mdx +++ b/docs/cli/plan.mdx @@ -2,17 +2,18 @@ title: "Plan" --- -The `plan` command generates a migration plan to apply a desired schema state to a target database schema. It compares the desired state (from a file) with the current state of a specific schema and shows what changes would be applied. +The `plan` command generates a migration plan to apply a desired schema state to a target database. It compares the desired state (from a file) with the current state of one or more PostgreSQL schemas and shows what changes would be applied. 
## Overview The plan command follows infrastructure-as-code principles similar to Terraform: + 1. Read the desired state from a SQL file (with include directive support) -1. Apply the desired state SQL to a temporary PostgreSQL instance (embedded by default, or external via `--plan-*` flags) -1. Connect to the target database and analyze current state of the specified schema -1. Compare the two states -1. Generate a detailed migration plan with proper dependency ordering -1. Display the plan without making any changes +2. Apply the desired state SQL to a temporary PostgreSQL instance (embedded by default, or external via `--plan-*` flags) +3. Connect to the target database and analyze current state of the specified schema(s) +4. Compare the two states +5. Generate a detailed migration plan with proper dependency ordering +6. Display the plan without making any changes By default, pgschema uses an embedded PostgreSQL instance to validate your desired state SQL. For schemas using PostgreSQL extensions or cross-schema references, you can use an external database instead. See [External Plan Database](/cli/plan-db) for details. @@ -125,7 +126,15 @@ pgschema plan --host localhost --db myapp --user postgres --password mypassword - Schema name to target for comparison + One PostgreSQL schema name, or a **comma-separated** list to plan across multiple namespaces (e.g. `public,app`). + + **Single schema** — Only that schema is read from the target database and from the plan database after applying `--file`. + + **Multi-schema** — Every name in the list is read from the target. On the plan side, the **first** name is the *primary* schema: your SQL is applied with that schema as the strip/normalize target (same as today’s single-schema flow). The temporary schema plus each additional listed schema is then introspected so qualified objects (e.g. `app.table`) appear in the desired state as long as you include `app` in the list. 
+ + Put the schema that contains most **unqualified** DDL in `--file` **first**. Duplicate names are ignored; spaces around commas are trimmed. + + The [apply](/cli/apply) command still documents a single `--schema` for fingerprinting and session defaults; review your workflow if the generated plan touches several namespaces. ## Plan Database Options @@ -359,6 +368,22 @@ pgschema plan \ --file tenant_schema.sql ``` +### Plan for multiple schemas + +When a migration spans more than one namespace (for example `public` and `app`), pass them as a comma-separated list: + +```bash +pgschema plan \ + --host localhost \ + --db myapp \ + --user postgres \ + --schema public,app \ + --file schema.sql +``` + +- **Target**: all listed schemas are loaded and merged for the “current” side of the diff. +- **Desired**: the first entry is primary (where unqualified objects in `--file` are rooted after temp-schema normalization); remaining entries are also introspected on the plan database so explicitly qualified DDL is visible to the diff. + ## Use Cases ### Pre-deployment Validation diff --git a/internal/diff/collector.go b/internal/diff/collector.go index 73c4cc59..3f613ac8 100644 --- a/internal/diff/collector.go +++ b/internal/diff/collector.go @@ -1,5 +1,7 @@ package diff +import "github.com/pgplex/pgschema/ir" + // diffContext provides context about the SQL statement being generated type diffContext struct { Type DiffType // e.g., DiffTypeTable, DiffTypeView, DiffTypeFunction @@ -11,14 +13,40 @@ type diffContext struct { // diffCollector collects SQL statements with their context information type diffCollector struct { - diffs []Diff + diffs []Diff + pendingForeignKeys []*deferredConstraint } // newDiffCollector creates a new diffCollector func newDiffCollector() *diffCollector { return &diffCollector{ - diffs: []Diff{}, + diffs: []Diff{}, + pendingForeignKeys: nil, + } +} + +// queueDeferredForeignKey schedules an ALTER TABLE ... 
ADD FOREIGN KEY for a later flush +// (after CREATE and MODIFY phases) so referenced tables and new PK/UNIQUE constraints exist. +func (c *diffCollector) queueDeferredForeignKey(table *ir.Table, constraint *ir.Constraint) { + if c == nil || table == nil || constraint == nil || constraint.Name == "" { + return + } + c.pendingForeignKeys = append(c.pendingForeignKeys, &deferredConstraint{ + table: table, + constraint: constraint, + }) +} + +// flushDeferredForeignKeys emits pending foreign keys in dependency order. +func (c *diffCollector) flushDeferredForeignKeys(targetSchema string) { + if c == nil || len(c.pendingForeignKeys) == 0 { + return + } + sorted := sortDeferredForeignKeys(c.pendingForeignKeys) + for _, item := range sorted { + emitDeferredForeignKeyConstraint(item, targetSchema, c) } + c.pendingForeignKeys = nil } // collect collects a single SQL statement with its context information diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 7870cba6..0572a739 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -42,6 +42,7 @@ const ( DiffTypePrivilege DiffTypeRevokedDefaultPrivilege DiffTypeColumnPrivilege + DiffTypeSchema ) // String returns the string representation of DiffType @@ -103,6 +104,8 @@ func (d DiffType) String() string { return "revoked_default_privilege" case DiffTypeColumnPrivilege: return "column_privilege" + case DiffTypeSchema: + return "schema" default: return "unknown" } @@ -177,6 +180,8 @@ func (d *DiffType) UnmarshalJSON(data []byte) error { *d = DiffTypeRevokedDefaultPrivilege case "column_privilege": *d = DiffTypeColumnPrivilege + case "schema": + *d = DiffTypeSchema default: return fmt.Errorf("unknown diff type: %s", s) } @@ -1401,6 +1406,9 @@ func (d *ddlDiff) collectMigrationSQL(targetSchema string, collector *diffCollec // Finally: Modify operations d.generateModifySQL(targetSchema, collector, preDroppedViews) + + // Foreign keys deferred from CREATE TABLE and ALTER TABLE (after all structural DDL). 
+ collector.flushDeferredForeignKeys(targetSchema) } // generatePreDropMaterializedViewsSQL drops materialized views that depend on @@ -1497,7 +1505,7 @@ func (d *ddlDiff) generatePreDropMaterializedViewsSQL(targetSchema string, colle // generateCreateSQL generates CREATE statements in dependency order func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollector) { - // Note: Schema creation is out of scope for schema-level comparisons + generateCreateSchemasSQL(d.addedSchemas, collector) // Build function lookup early - needed for both domain and table dependency checks newFunctionLookup := buildFunctionLookup(d.addedFunctions) @@ -1600,10 +1608,12 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto // Create tables WITH function/domain dependencies (now that functions and deferred domains exist) deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithDeps, targetSchema, collector, existingTables, shouldDeferPolicy) - // Add deferred foreign key constraints from BOTH batches AFTER all tables are created - // This ensures FK references to tables in the second batch (function-dependent tables) work correctly + // Queue foreign key constraints from BOTH batches for a flush after MODIFY so referenced + // tables and new PK/UNIQUE constraints from ALTER TABLE exist before ADD FOREIGN KEY. allDeferredConstraints := append(deferredConstraints1, deferredConstraints2...) - generateDeferredConstraintsSQL(allDeferredConstraints, targetSchema, collector) + for _, dc := range allDeferredConstraints { + collector.queueDeferredForeignKey(dc.table, dc.constraint) + } // Merge deferred policies from both batches allDeferredPolicies := append(deferredPolicies1, deferredPolicies2...) 
@@ -1721,8 +1731,8 @@ func (d *ddlDiff) generateDropSQL(targetSchema string, collector *diffCollector, // Drop types generateDropTypesSQL(d.droppedTypes, targetSchema, collector) - // Drop schemas - // Note: Schema deletion is out of scope for schema-level comparisons + // Drop namespaces last (CASCADE removes any remaining objects in the schema). + generateDropSchemasSQL(d.droppedSchemas, collector) } // filterPreDroppedViews returns views that haven't been pre-dropped diff --git a/internal/diff/schema.go b/internal/diff/schema.go new file mode 100644 index 00000000..0f39943b --- /dev/null +++ b/internal/diff/schema.go @@ -0,0 +1,66 @@ +package diff + +import ( + "fmt" + "sort" + + "github.com/pgplex/pgschema/ir" +) + +// generateCreateSchemasSQL emits CREATE SCHEMA IF NOT EXISTS (and optional AUTHORIZATION) +// for each namespace in addedSchemas before any object DDL. +func generateCreateSchemasSQL(schemas []*ir.Schema, collector *diffCollector) { + if len(schemas) == 0 { + return + } + sorted := append([]*ir.Schema(nil), schemas...) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Name < sorted[j].Name + }) + for _, s := range sorted { + if s == nil { + continue + } + schemaLit := ir.QuoteIdentifier(s.Name) + sql := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaLit) + if s.Owner != "" { + sql += fmt.Sprintf(" AUTHORIZATION %s", ir.QuoteIdentifier(s.Owner)) + } + sql += ";" + ctx := &diffContext{ + Type: DiffTypeSchema, + Operation: DiffOperationCreate, + Path: s.Name, + Source: s, + CanRunInTransaction: true, + } + collector.collect(ctx, sql) + } +} + +// generateDropSchemasSQL emits DROP SCHEMA IF EXISTS ... CASCADE for removed namespaces. +// CASCADE removes any objects still in the schema so the drop succeeds even if pgschema +// did not emit drops for every contained object (defensive; typical flows already dropped children). 
+func generateDropSchemasSQL(schemas []*ir.Schema, collector *diffCollector) { + if len(schemas) == 0 { + return + } + sorted := append([]*ir.Schema(nil), schemas...) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Name > sorted[j].Name + }) + for _, s := range sorted { + if s == nil { + continue + } + sql := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE;", ir.QuoteIdentifier(s.Name)) + ctx := &diffContext{ + Type: DiffTypeSchema, + Operation: DiffOperationDrop, + Path: s.Name, + Source: s, + CanRunInTransaction: true, + } + collector.collect(ctx, sql) + } +} diff --git a/internal/diff/schema_test.go b/internal/diff/schema_test.go new file mode 100644 index 00000000..692bffb8 --- /dev/null +++ b/internal/diff/schema_test.go @@ -0,0 +1,35 @@ +package diff + +import ( + "strings" + "testing" + + "github.com/pgplex/pgschema/ir" +) + +func TestGenerateCreateSchemasSQL_includesOwner(t *testing.T) { + c := newDiffCollector() + generateCreateSchemasSQL([]*ir.Schema{{Name: "app", Owner: "postgres"}}, c) + if len(c.diffs) != 1 { + t.Fatalf("expected 1 diff, got %d", len(c.diffs)) + } + sql := c.diffs[0].Statements[0].SQL + if !strings.Contains(sql, "CREATE SCHEMA IF NOT EXISTS") || !strings.Contains(sql, `"app"`) { + t.Fatalf("unexpected sql: %s", sql) + } + if !strings.Contains(sql, "AUTHORIZATION") || !strings.Contains(sql, `"postgres"`) { + t.Fatalf("expected AUTHORIZATION, got: %s", sql) + } +} + +func TestGenerateDropSchemasSQL_usesCascade(t *testing.T) { + c := newDiffCollector() + generateDropSchemasSQL([]*ir.Schema{{Name: "app"}}, c) + if len(c.diffs) != 1 { + t.Fatalf("expected 1 diff, got %d", len(c.diffs)) + } + sql := c.diffs[0].Statements[0].SQL + if !strings.Contains(sql, "DROP SCHEMA IF EXISTS") || !strings.Contains(sql, "CASCADE") { + t.Fatalf("unexpected sql: %s", sql) + } +} diff --git a/internal/diff/table.go b/internal/diff/table.go index cb4d6d94..35699aaa 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -507,37 +507,40 
@@ func generateCreateTablesSQL( func generateDeferredConstraintsSQL(deferred []*deferredConstraint, targetSchema string, collector *diffCollector) { for _, item := range deferred { - constraint := item.constraint - if constraint == nil || item.table == nil || constraint.Name == "" { - continue - } + emitDeferredForeignKeyConstraint(item, targetSchema, collector) + } +} - columns := sortConstraintColumnsByPosition(constraint.Columns) - var columnNames []string - for _, col := range columns { - columnNames = append(columnNames, ir.QuoteIdentifier(col.Name)) - } - if constraint.IsTemporal && len(columnNames) > 0 { - columnNames[len(columnNames)-1] = "PERIOD " + columnNames[len(columnNames)-1] - } +func emitDeferredForeignKeyConstraint(item *deferredConstraint, targetSchema string, collector *diffCollector) { + if item == nil || item.constraint == nil || item.table == nil || item.constraint.Name == "" { + return + } + constraint := item.constraint + columns := sortConstraintColumnsByPosition(constraint.Columns) + var columnNames []string + for _, col := range columns { + columnNames = append(columnNames, ir.QuoteIdentifier(col.Name)) + } + if constraint.IsTemporal && len(columnNames) > 0 { + columnNames[len(columnNames)-1] = "PERIOD " + columnNames[len(columnNames)-1] + } - tableName := getTableNameWithSchema(item.table.Schema, item.table.Name, targetSchema) - sql := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s FOREIGN KEY (%s) %s;", - tableName, - ir.QuoteIdentifier(constraint.Name), - strings.Join(columnNames, ", "), - generateForeignKeyClause(constraint, targetSchema, false), - ) + tableName := getTableNameWithSchema(item.table.Schema, item.table.Name, targetSchema) + sql := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s FOREIGN KEY (%s) %s;", + tableName, + ir.QuoteIdentifier(constraint.Name), + strings.Join(columnNames, ", "), + generateForeignKeyClause(constraint, targetSchema, false), + ) - context := &diffContext{ - Type: DiffTypeTableConstraint, - 
Operation: DiffOperationCreate, - Path: fmt.Sprintf("%s.%s.%s", item.table.Schema, item.table.Name, constraint.Name), - Source: constraint, - CanRunInTransaction: true, - } - collector.collect(context, sql) + context := &diffContext{ + Type: DiffTypeTableConstraint, + Operation: DiffOperationCreate, + Path: fmt.Sprintf("%s.%s.%s", item.table.Schema, item.table.Name, constraint.Name), + Source: constraint, + CanRunInTransaction: true, } + collector.collect(context, sql) } // generateModifyTablesSQL generates ALTER TABLE statements @@ -810,9 +813,11 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } inlineConstraint = fmt.Sprintf(" CONSTRAINT %s UNIQUE%s", ir.QuoteIdentifier(constraint.Name), modifier) case ir.ConstraintTypeForeignKey: - // For FK, use the generateForeignKeyClause with inline=true - fkClause := generateForeignKeyClause(constraint, targetSchema, true) - inlineConstraint = fmt.Sprintf(" CONSTRAINT %s%s", ir.QuoteIdentifier(constraint.Name), fkClause) + // Defer FK to flush phase so referenced tables and new PK/UNIQUE exist. 
+ inlineConstraints[constraint.Name] = true + if collector != nil { + collector.queueDeferredForeignKey(td.Table, constraint) + } case ir.ConstraintTypeCheck: // For CHECK, format the clause inline checkExpr := constraint.CheckClause @@ -947,31 +952,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector collector.collect(context, canonicalSQL) case ir.ConstraintTypeForeignKey: - // Sort columns by position - columns := sortConstraintColumnsByPosition(constraint.Columns) - var columnNames []string - for _, col := range columns { - columnNames = append(columnNames, ir.QuoteIdentifier(col.Name)) - } - if constraint.IsTemporal && len(columnNames) > 0 { - columnNames[len(columnNames)-1] = "PERIOD " + columnNames[len(columnNames)-1] + if collector != nil { + collector.queueDeferredForeignKey(td.Table, constraint) } - tableName := getTableNameWithSchema(td.Table.Schema, td.Table.Name, targetSchema) - canonicalSQL := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s FOREIGN KEY (%s) %s;", - tableName, ir.QuoteIdentifier(constraint.Name), - strings.Join(columnNames, ", "), - generateForeignKeyClause(constraint, targetSchema, false)) - - context := &diffContext{ - Type: DiffTypeTableConstraint, - Operation: DiffOperationCreate, - Path: fmt.Sprintf("%s.%s.%s", td.Table.Schema, td.Table.Name, constraint.Name), - Source: constraint, - CanRunInTransaction: true, - } - collector.collect(context, canonicalSQL) - case ir.ConstraintTypePrimaryKey: // Sort columns by position columns := sortConstraintColumnsByPosition(constraint.Columns) @@ -1059,20 +1043,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector tableName, ir.QuoteIdentifier(constraint.Name), ensureCheckClauseParens(constraint.CheckClause), suffix) case ir.ConstraintTypeForeignKey: - // Sort columns by position - columns := sortConstraintColumnsByPosition(constraint.Columns) - var columnNames []string - for _, col := range columns { - columnNames = 
append(columnNames, ir.QuoteIdentifier(col.Name)) - } - if constraint.IsTemporal && len(columnNames) > 0 { - columnNames[len(columnNames)-1] = "PERIOD " + columnNames[len(columnNames)-1] + if collector != nil { + collector.queueDeferredForeignKey(td.Table, constraint) } - - addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s FOREIGN KEY (%s) %s;", - tableName, ir.QuoteIdentifier(constraint.Name), - strings.Join(columnNames, ", "), - generateForeignKeyClause(constraint, targetSchema, false)) + addSQL = "" case ir.ConstraintTypePrimaryKey: // Sort columns by position @@ -1100,7 +1074,9 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector CanRunInTransaction: true, } - collector.collect(addContext, addSQL) + if addSQL != "" { + collector.collect(addContext, addSQL) + } } // Handle RLS changes diff --git a/internal/diff/topological.go b/internal/diff/topological.go index 3602b5f0..8a883b72 100644 --- a/internal/diff/topological.go +++ b/internal/diff/topological.go @@ -630,6 +630,116 @@ func topologicallySortModifiedTables(tableDiffs []*tableDiff) []*tableDiff { return sortedTableDiffs } +// sortDeferredForeignKeys orders pending ALTER TABLE ADD FOREIGN KEY steps using the same +// dependency rule as topologicallySortTables: referenced table precedes the table that +// receives the foreign key. Cycles are broken deterministically via insertion order. +func sortDeferredForeignKeys(items []*deferredConstraint) []*deferredConstraint { + nonNil := make([]*deferredConstraint, 0, len(items)) + for _, it := range items { + if it != nil && it.table != nil && it.constraint != nil && it.constraint.Type == ir.ConstraintTypeForeignKey { + nonNil = append(nonNil, it) + } + } + if len(nonNil) <= 1 { + return nonNil + } + + tableMap := make(map[string]bool) + for _, it := range nonNil { + key := it.table.Schema + "." 
+ it.table.Name + tableMap[key] = true + refSchema := it.constraint.ReferencedSchema + if refSchema == "" { + refSchema = it.table.Schema + } + if it.constraint.ReferencedTable != "" { + refKey := refSchema + "." + it.constraint.ReferencedTable + tableMap[refKey] = true + } + } + insertionOrder := make([]string, 0, len(tableMap)) + for k := range tableMap { + insertionOrder = append(insertionOrder, k) + } + sort.Strings(insertionOrder) + + inDegree := make(map[string]int) + adjList := make(map[string][]string) + for key := range tableMap { + inDegree[key] = 0 + adjList[key] = []string{} + } + + for _, it := range nonNil { + refSchema := it.constraint.ReferencedSchema + if refSchema == "" { + refSchema = it.table.Schema + } + if it.constraint.ReferencedTable == "" { + continue + } + keyA := it.table.Schema + "." + it.table.Name + keyB := refSchema + "." + it.constraint.ReferencedTable + if _, exists := tableMap[keyB]; exists && keyA != keyB { + adjList[keyB] = append(adjList[keyB], keyA) + inDegree[keyA]++ + } + } + + var queue []string + var result []string + processed := make(map[string]bool, len(tableMap)) + for key, degree := range inDegree { + if degree == 0 { + queue = append(queue, key) + } + } + sort.Strings(queue) + + for len(result) < len(tableMap) { + if len(queue) == 0 { + next := nextInOrder(insertionOrder, processed) + if next == "" { + break + } + queue = append(queue, next) + inDegree[next] = 0 + } + current := queue[0] + queue = queue[1:] + if processed[current] { + continue + } + processed[current] = true + result = append(result, current) + neighbors := append([]string(nil), adjList[current]...) 
+ sort.Strings(neighbors) + for _, neighbor := range neighbors { + inDegree[neighbor]-- + if inDegree[neighbor] <= 0 && !processed[neighbor] { + queue = append(queue, neighbor) + sort.Strings(queue) + } + } + } + + rank := make(map[string]int, len(result)) + for i, k := range result { + rank[k] = i + } + sorted := append([]*deferredConstraint(nil), nonNil...) + sort.Slice(sorted, func(i, j int) bool { + ki := sorted[i].table.Schema + "." + sorted[i].table.Name + kj := sorted[j].table.Schema + "." + sorted[j].table.Name + ri, rj := rank[ki], rank[kj] + if ri != rj { + return ri < rj + } + return sorted[i].constraint.Name < sorted[j].constraint.Name + }) + return sorted +} + // constraintMatchesFKReference checks if a UNIQUE/PK constraint matches the columns // referenced by a foreign key constraint. // In PostgreSQL, composite foreign keys must reference columns in the same order as they diff --git a/internal/diff/topological_test.go b/internal/diff/topological_test.go index adc145a0..f0267273 100644 --- a/internal/diff/topological_test.go +++ b/internal/diff/topological_test.go @@ -439,3 +439,35 @@ func TestBuildFunctionBodyDependenciesWithTopologicalSort(t *testing.T) { t.Errorf("expected a_helper before z_wrapper, got order: %v", order) } } + +func TestSortDeferredForeignKeys_ordersRefBeforeChild(t *testing.T) { + child := &ir.Table{Schema: "app", Name: "child"} + fk := &ir.Constraint{ + Name: "child_parent_fk", + Type: ir.ConstraintTypeForeignKey, + Columns: []*ir.ConstraintColumn{{Name: "pid", Position: 1}}, + ReferencedSchema: "public", + ReferencedTable: "parent", + ReferencedColumns: []*ir.ConstraintColumn{{Name: "id", Position: 1}}, + } + fk2 := &ir.Constraint{ + Name: "other_child_fk", + Type: ir.ConstraintTypeForeignKey, + Columns: []*ir.ConstraintColumn{{Name: "cid", Position: 1}}, + ReferencedSchema: "app", + ReferencedTable: "child", + ReferencedColumns: []*ir.ConstraintColumn{{Name: "id", Position: 1}}, + } + other := &ir.Table{Schema: "public", Name: 
"other"} + items2 := []*deferredConstraint{ + {table: other, constraint: fk2}, + {table: child, constraint: fk}, + } + sorted2 := sortDeferredForeignKeys(items2) + if len(sorted2) != 2 { + t.Fatalf("got %d", len(sorted2)) + } + if sorted2[0].table.Name != "child" || sorted2[1].table.Name != "other" { + t.Fatalf("expected child FK before other FK, got %v then %v", sorted2[0].table.Name, sorted2[1].table.Name) + } +} diff --git a/internal/dump/formatter.go b/internal/dump/formatter.go index 6d1524cf..48768c6e 100644 --- a/internal/dump/formatter.go +++ b/internal/dump/formatter.go @@ -247,6 +247,8 @@ func (f *DumpFormatter) getObjectDirectory(objectType string) string { return "procedures" case "table": return "tables" + case "schema": + return "schemas" case "view": return "views" case "materialized_view": @@ -279,6 +281,14 @@ func (f *DumpFormatter) getObjectDirectory(objectType string) string { func (f *DumpFormatter) getGroupingName(step diff.Diff) string { // For table-related objects, try to extract the table name from Source switch step.Type { + case diff.DiffTypeSchema: + if step.Source != nil { + switch obj := step.Source.(type) { + case *ir.Schema: + return obj.Name + } + } + return step.Path case diff.DiffTypeTableIndex, diff.DiffTypeTableTrigger, diff.DiffTypeTableConstraint, diff.DiffTypeTablePolicy, diff.DiffTypeTableRLS, diff.DiffTypeTableComment, diff.DiffTypeTableColumnComment, diff.DiffTypeTableIndexComment: if tableName := f.extractTableNameFromContext(step); tableName != "" { return tableName diff --git a/internal/fingerprint/fingerprint.go b/internal/fingerprint/fingerprint.go index af9a3731..af368e3d 100644 --- a/internal/fingerprint/fingerprint.go +++ b/internal/fingerprint/fingerprint.go @@ -15,6 +15,36 @@ type SchemaFingerprint struct { // ComputeFingerprint generates a fingerprint for the given IR and schema func ComputeFingerprint(schemaIR *ir.IR, schemaName string) (*SchemaFingerprint, error) { + return 
ComputeFingerprintForSchemas(schemaIR, []string{schemaName}) +} + +// ComputeFingerprintForSchemas hashes the union of the listed PostgreSQL namespaces. +func ComputeFingerprintForSchemas(schemaIR *ir.IR, schemaNames []string) (*SchemaFingerprint, error) { + if len(schemaNames) == 0 { + schemaNames = []string{"public"} + } + if len(schemaNames) == 1 { + return computeFingerprintSingle(schemaIR, schemaNames[0]) + } + + subset := make(map[string]*ir.Schema) + for _, name := range schemaNames { + if s := schemaIR.Schemas[name]; s != nil { + subset[name] = s + } + } + partial := &ir.IR{ + Metadata: schemaIR.Metadata, + Schemas: subset, + } + hash, err := hashObject(partial) + if err != nil { + return nil, fmt.Errorf("failed to compute schema hash: %w", err) + } + return &SchemaFingerprint{Hash: hash}, nil +} + +func computeFingerprintSingle(schemaIR *ir.IR, schemaName string) (*SchemaFingerprint, error) { // Get the target schema, default to "public" if not found targetSchema := schemaIR.Schemas[schemaName] if targetSchema == nil && schemaName == "public" { diff --git a/ir/inspector.go b/ir/inspector.go index 23a0a5c4..e3a44854 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -50,26 +50,61 @@ func NewInspector(db *sql.DB, ignoreConfig *IgnoreConfig) *Inspector { // BuildIR builds the schema IR from the database for a specific schema func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, error) { - schema := NewIR() + return i.BuildIRFromSchemas(ctx, []string{targetSchema}) +} + +// BuildIRFromSchemas merges IR from multiple PostgreSQL namespaces into one IR value. +// Order can matter when later stages depend on objects created in an earlier schema pass. 
+func (i *Inspector) BuildIRFromSchemas(ctx context.Context, targetSchemas []string) (*IR, error) { + if len(targetSchemas) == 0 { + targetSchemas = []string{"public"} + } + if len(targetSchemas) == 1 { + return i.buildIRSingle(ctx, targetSchemas[0]) + } - // Sequential prerequisites + schema := NewIR() if err := i.buildMetadata(ctx, schema); err != nil { return nil, fmt.Errorf("failed to build metadata: %w", err) } + for _, ts := range targetSchemas { + if err := i.validateSchemaExists(ctx, ts); err != nil { + return nil, fmt.Errorf("schema %q: %w", ts, err) + } + if err := i.buildSchemaContent(ctx, schema, ts); err != nil { + return nil, fmt.Errorf("schema %q: %w", ts, err) + } + } + + normalizeIR(schema) + return schema, nil +} + +func (i *Inspector) buildIRSingle(ctx context.Context, targetSchema string) (*IR, error) { + schema := NewIR() + if err := i.buildMetadata(ctx, schema); err != nil { + return nil, fmt.Errorf("failed to build metadata: %w", err) + } if err := i.validateSchemaExists(ctx, targetSchema); err != nil { return nil, err } + if err := i.buildSchemaContent(ctx, schema, targetSchema); err != nil { + return nil, err + } + normalizeIR(schema) + return schema, nil +} +// buildSchemaContent loads all objects for one namespace into schema (merge). 
+func (i *Inspector) buildSchemaContent(ctx context.Context, schema *IR, targetSchema string) error { if err := i.buildSchemas(ctx, schema, targetSchema); err != nil { - return nil, fmt.Errorf("failed to build schemas: %w", err) + return fmt.Errorf("failed to build schemas: %w", err) } - if err := i.buildTables(ctx, schema, targetSchema); err != nil { - return nil, fmt.Errorf("failed to build tables: %w", err) + return fmt.Errorf("failed to build tables: %w", err) } - // Concurrent Group 1: Table Details group1 := queryGroup{ name: "table details", funcs: []func(context.Context, *IR, string) error{ @@ -78,8 +113,6 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro i.buildPartitions, }, } - - // Concurrent Group 2: Independent Objects group2 := queryGroup{ name: "independent objects", funcs: []func(context.Context, *IR, string) error{ @@ -93,16 +126,12 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro i.buildColumnPrivileges, }, } - - // Concurrent Group 3: View and table-dependent objects (views must load first) group3 := queryGroup{ name: "views", funcs: []func(context.Context, *IR, string) error{ i.buildViews, }, } - - // Group 4: Objects that depend on both tables AND views (must run after views are loaded) group4 := queryGroup{ name: "table-and-view-dependent objects", funcs: []func(context.Context, *IR, string) error{ @@ -111,46 +140,30 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro }, } - // Execute groups concurrently where possible var eg errgroup.Group - - // Group 1 & 2 can run in parallel eg.Go(func() error { return i.executeConcurrentGroup(ctx, schema, targetSchema, group1) }) - eg.Go(func() error { return i.executeConcurrentGroup(ctx, schema, targetSchema, group2) }) - if err := eg.Wait(); err != nil { - return nil, err + return err } - // Build function dependencies after functions are loaded if err := i.buildFunctionDependencies(ctx, schema, 
targetSchema); err != nil {
-		return nil, err
+		return fmt.Errorf("failed to build function dependencies: %w", err)
 	}
-
-	// Group 3 runs after table details are loaded (views must be loaded before triggers)
 	if err := i.executeConcurrentGroup(ctx, schema, targetSchema, group3); err != nil {
-		return nil, err
+		return err
 	}
-
-	// Group 4 runs after views are loaded (triggers can reference views for INSTEAD OF triggers)
 	if err := i.executeConcurrentGroup(ctx, schema, targetSchema, group4); err != nil {
-		return nil, err
+		return err
 	}
-
-	// Build indexes after views are loaded (indexes can reference materialized views)
 	if err := i.buildIndexes(ctx, schema, targetSchema); err != nil {
-		return nil, fmt.Errorf("failed to build indexes: %w", err)
+		return fmt.Errorf("failed to build indexes: %w", err)
 	}
-
-	// Normalize the IR
-	normalizeIR(schema)
-
-	return schema, nil
+	return nil
 }
 
 // queryGroup represents a group of queries that can be executed concurrently
diff --git a/ir/ir.go b/ir/ir.go
index fb0ff718..cf183a42 100644
--- a/ir/ir.go
+++ b/ir/ir.go
@@ -709,4 +709,5 @@ func (p *Procedure) GetObjectName() string { return p.Name }
 func (v *View) GetObjectName() string { return v.Name }
 func (s *Sequence) GetObjectName() string { return s.Name }
 func (t *Type) GetObjectName() string { return t.Name }
+func (s *Schema) GetObjectName() string { return s.Name }
 