Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
if snapshot.node.project in self._projects:
uncached.add(snapshot.name)
else:
store = self._standalone_audits if snapshot.is_audit else self._models
store[snapshot.name] = snapshot.node # type: ignore
local_store = self._standalone_audits if snapshot.is_audit else self._models
if snapshot.name in local_store:
uncached.add(snapshot.name)
else:
local_store[snapshot.name] = snapshot.node # type: ignore

for model in self._models.values():
self.dag.add(model.fqn, model.depends_on)
Expand Down
105 changes: 105 additions & 0 deletions tests/core/integration/test_multi_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,111 @@ def test_multi_hybrid(mocker):
validate_apply_basics(context, c.PROD, plan.snapshots.values())


def test_multi_repo_no_project_to_project(copy_to_temp_path):
    """Adopt a project name on a repo whose prod snapshots were created without one.

    First plans/applies the repo with the ``project:`` key stripped from its
    config (so prod snapshots carry an empty project), then re-adds the key and
    verifies that the local models — and a fresh plan — pick up the new project
    name rather than colliding with the project-less prod snapshots.
    """
    paths = copy_to_temp_path("examples/multi")
    repo_path = f"{paths[0]}/repo_1"
    config_path = f"{repo_path}/config.yaml"

    def _rewrite_config(transform):
        # Read-modify-write the repo config in place.
        with open(config_path, "r") as f:
            contents = f.read()
        with open(config_path, "w") as f:
            f.write(transform(contents))

    # Start out with no project name at all.
    _rewrite_config(lambda contents: contents.replace("project: repo_1\n", ""))

    context = Context(paths=[repo_path], gateway="memory")
    context._new_state_sync().reset(default_catalog=context.default_catalog)
    context.apply(context.plan_builder().build())

    # Every snapshot promoted to prod should carry an empty project.
    prod_snapshots = context.state_reader.get_snapshots(
        context.state_reader.get_environment(c.PROD).snapshots
    )
    for snapshot in prod_snapshots.values():
        assert snapshot.node.project == ""

    # Adopt multi-project mode by prepending the project key.
    _rewrite_config(lambda contents: "project: repo_1\n" + contents)

    context_with_project = Context(
        paths=[repo_path],
        state_sync=context.state_sync,
        gateway="memory",
    )
    context_with_project._engine_adapter = context.engine_adapter
    del context_with_project.engine_adapters

    # Local definitions win over the project-less prod snapshots,
    # so they should carry the freshly added project name.
    for model_name in ("bronze.a", "bronze.b"):
        assert context_with_project.get_model(model_name).project == "repo_1"

    # A plan under the new project name should still apply cleanly.
    plan = context_with_project.plan_builder().build()
    context_with_project.apply(plan)
    validate_apply_basics(context_with_project, c.PROD, plan.snapshots.values())


def test_multi_repo_local_model_overrides_prod_from_other_project(copy_to_temp_path):
    """A locally defined model shadows the prod snapshot owned by another project.

    Applies both repos, then redefines ``silver.c`` inside repo_1 and verifies
    that loading repo_1 alone resolves ``silver.c`` to the local definition
    (attributed to repo_1) while the plan still picks up its downstream
    dependency ``silver.d``.
    """
    paths = copy_to_temp_path("examples/multi")
    repo_1 = f"{paths[0]}/repo_1"
    repo_2 = f"{paths[0]}/repo_2"

    context = Context(paths=[repo_1, repo_2], gateway="memory")
    context._new_state_sync().reset(default_catalog=context.default_catalog)
    initial_plan = context.plan_builder().build()
    assert len(initial_plan.new_snapshots) == 5
    context.apply(initial_plan)

    # In prod, silver.c belongs to repo_2.
    assert context.get_model("silver.c").project == "repo_2"

    # Redefine silver.c locally inside repo_1 with an extra column.
    override_sql = dedent("""\
    MODEL (
      name silver.c,
      kind FULL
    );
    SELECT DISTINCT col_a, col_b
    FROM bronze.a
    """)
    with open(f"{repo_1}/models/c.sql", "w") as f:
        f.write(override_sql)

    # silver.c now exists locally in repo 1 AND in prod under repo_2.
    context_repo1 = Context(
        paths=[repo_1],
        state_sync=context.state_sync,
        gateway="memory",
    )
    context_repo1._engine_adapter = context.engine_adapter
    del context_repo1.engine_adapters

    # The local override wins and is attributed to the loading project.
    assert context_repo1.get_model("silver.c").project == "repo_1"
    assert "col_b" in context_repo1.render("silver.c").sql()

    # Both the overridden model and its downstream silver.d are planned.
    plan = context_repo1.plan_builder().build()
    changed_names = {snapshot.name for snapshot in plan.directly_modified}
    missing_names = {interval.snapshot_id.name for interval in plan.missing_intervals}
    for expected in ('"memory"."silver"."c"', '"memory"."silver"."d"'):
        assert expected in changed_names
        assert expected in missing_names

    context_repo1.apply(plan)
    validate_apply_basics(context_repo1, c.PROD, plan.snapshots.values())
    result = context_repo1.fetchdf("SELECT * FROM memory.silver.c")
    assert "col_b" in result.columns


def test_engine_adapters_multi_repo_all_gateways_gathered(copy_to_temp_path):
paths = copy_to_temp_path("examples/multi")
repo_1_path = paths[0] / "repo_1"
Expand Down