Skip to content

Commit cb8acc5

Browse files
authored
feat(models): render blueprint vars in columns dict (#5845)
Signed-off-by: lafirm <136463254+lafirm@users.noreply.github.com>
1 parent b44fdf6 commit cb8acc5

9 files changed

Lines changed: 160 additions & 26 deletions

File tree

docs/concepts/models/python_models.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,33 @@ def entrypoint(
369369
)
370370
```
371371

372+
Blueprint variables can also be used as **column names and column types** in the `columns` dictionary. For example, if each blueprint produces a model with a different set of column names and types, both can be parameterized using the same `@{variable}` syntax:
373+
374+
```python linenums="1"
375+
import pandas as pd
376+
from sqlmesh import ExecutionContext, model
377+
378+
@model(
379+
"@{customer}.metrics",
380+
kind="FULL",
381+
blueprints=[
382+
{"customer": "customer1", "primary_metric": "revenue", "primary_type": "int", "secondary_metric": "cost", "secondary_type": "double"},
383+
{"customer": "customer2", "primary_metric": "sales", "primary_type": "text", "secondary_metric": "profit", "secondary_type": "double"},
384+
],
385+
columns={
386+
"@{primary_metric}": "@{primary_type}",
387+
"@{secondary_metric}": "@{secondary_type}",
388+
},
389+
)
390+
def entrypoint(context: ExecutionContext, **kwargs) -> pd.DataFrame:
391+
return pd.DataFrame({
392+
context.blueprint_var("primary_metric"): [1],
393+
context.blueprint_var("secondary_metric"): [1.5],
394+
})
395+
```
396+
397+
Global variables (defined in the project config) can also be used as column names and types in the same way.
398+
372399
Note the use of curly brace syntax `@{customer}` in the model name above. It is used to ensure SQLMesh can combine the macro variable into the model name identifier correctly - learn more [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings).
373400

374401
Blueprint variable mappings can also be constructed dynamically, e.g., by using a macro: `blueprints="@gen_blueprints()"`. This is useful in cases where the `blueprints` list needs to be sourced from external sources, such as CSV files.

sqlmesh/core/dialect.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,8 +1096,9 @@ def extend_sqlglot() -> None:
10961096
DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
10971097
Jinja: lambda self, e: e.name,
10981098
JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
1099-
JinjaStatement: lambda self,
1100-
e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
1099+
JinjaStatement: lambda self, e: (
1100+
f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};"
1101+
),
11011102
VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
11021103
MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
11031104
MacroFunc: _macro_func_sql,

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ def _get_source_queries(
300300
)
301301
for c in target_columns_to_types
302302
]
303-
query_factory = (
304-
lambda: exp.Select()
303+
query_factory = lambda: (
304+
exp.Select()
305305
.select(*select_columns)
306306
.from_(query_or_df.subquery("select_source_columns"))
307307
)

sqlmesh/core/model/decorator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
7474

7575
self.columns = {
7676
column_name: (
77-
column_type
78-
if isinstance(column_type, exp.DataType)
77+
column_type # Column types with macros (containing @) will be validated later after rendering
78+
if isinstance(column_type, exp.DataType) or "@" in column_type
7979
else exp.DataType.build(
8080
str(column_type), dialect=self.kwargs.get("dialect", self._dialect)
8181
)
8282
)
83-
for column_name, column_type in self.kwargs.pop("columns", {}).items()
83+
for column_name, column_type in self.kwargs.get("columns", {}).items()
8484
}
8585

8686
def __call__(
@@ -196,6 +196,8 @@ def model(
196196
if isinstance(rendered_name, exp.Expr):
197197
rendered_fields["name"] = rendered_name.sql(dialect=dialect)
198198

199+
rendered_columns = rendered_fields.get("columns")
200+
199201
rendered_defaults = (
200202
render_model_defaults(
201203
defaults=defaults,
@@ -223,7 +225,7 @@ def model(
223225
"default_catalog": default_catalog,
224226
"variables": variables,
225227
"dialect": dialect,
226-
"columns": self.columns if self.columns else None,
228+
"columns": rendered_columns if rendered_columns else None,
227229
"module_path": module_path,
228230
"macros": macros,
229231
"jinja_macros": jinja_macros,

sqlmesh/core/model/definition.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2977,7 +2977,15 @@ def render_field_value(value: t.Any) -> t.Any:
29772977
if isinstance(field_value, dict):
29782978
rendered_dict = {}
29792979
for key, value in field_value.items():
2980-
if key in RUNTIME_RENDERED_MODEL_FIELDS:
2980+
if field == "columns":
2981+
column_name = render_field_value(key)
2982+
column_type = render_field_value(value)
2983+
# If column_type is an Expr (from rendering macros), convert to string.
2984+
# Otherwise, leave it as-is (string) for the validator to parse with the correct dialect.
2985+
if isinstance(column_type, exp.Expr):
2986+
column_type = column_type.sql(dialect=dialect)
2987+
rendered_dict[column_name] = column_type
2988+
elif key in RUNTIME_RENDERED_MODEL_FIELDS:
29812989
rendered_dict[key] = parse_strings_with_macro_refs(value, dialect)
29822990
elif (
29832991
# don't parse kind auto_restatement_cron="@..." kwargs (e.g. @daily) into MacroVar

sqlmesh/lsp/reference.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,9 @@ def get_model_find_all_references(
332332
# Find the model reference at the cursor position
333333
model_at_position = next(
334334
filter(
335-
lambda ref: isinstance(ref, ModelReference)
336-
and _position_within_range(position, ref.range),
335+
lambda ref: (
336+
isinstance(ref, ModelReference) and _position_within_range(position, ref.range)
337+
),
337338
get_model_definitions_for_a_path(lint_context, document_uri),
338339
),
339340
None,
@@ -486,8 +487,9 @@ def get_macro_find_all_references(
486487
# Find the macro reference at the cursor position
487488
macro_at_position = next(
488489
filter(
489-
lambda ref: isinstance(ref, MacroReference)
490-
and _position_within_range(position, ref.range),
490+
lambda ref: (
491+
isinstance(ref, MacroReference) and _position_within_range(position, ref.range)
492+
),
491493
get_macro_definitions_for_a_path(lsp_context, document_uri),
492494
),
493495
None,
@@ -517,9 +519,11 @@ def get_macro_find_all_references(
517519

518520
# Get macro references that point to the same macro definition
519521
matching_refs = filter(
520-
lambda ref: isinstance(ref, MacroReference)
521-
and ref.path == target_macro_path
522-
and ref.target_range == target_macro_target_range,
522+
lambda ref: (
523+
isinstance(ref, MacroReference)
524+
and ref.path == target_macro_path
525+
and ref.target_range == target_macro_target_range
526+
),
523527
get_macro_definitions_for_a_path(lsp_context, file_uri),
524528
)
525529

tests/core/test_model.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,9 @@ def test_model_union_query(sushi_context, assert_exp_eq):
334334
"@get_date() == '1996-02-10'",
335335
"'all'",
336336
3,
337-
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n",
337+
lambda expected_select: (
338+
f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n"
339+
),
338340
),
339341
# Test case 4: DISTINCT type
340342
(
@@ -374,7 +376,9 @@ def test_model_union_query(sushi_context, assert_exp_eq):
374376
"",
375377
"",
376378
3,
377-
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n",
379+
lambda expected_select: (
380+
f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n"
381+
),
378382
),
379383
# Test case 9: Missing union type AND condition one table
380384
(
@@ -10353,6 +10357,94 @@ def entrypoint(context, *args, **kwargs):
1035310357
assert ctx.fetchdf("SELECT * FROM test_schema2.foo").to_dict() == {"id": {0: 1}}
1035410358

1035510359

10360+
def test_python_model_blueprint_column_names(tmp_path: Path) -> None:
10361+
"""Blueprint variables can be used as column names and types in Python model definitions."""
10362+
py_model = tmp_path / "models" / "blueprint_col_names.py"
10363+
py_model.parent.mkdir(parents=True, exist_ok=True)
10364+
py_model.write_text(
10365+
"""
10366+
import pandas as pd # noqa: TID253
10367+
from sqlmesh import model
10368+
10369+
@model(
10370+
"test_schema.@model_name",
10371+
blueprints=[
10372+
{"model_name": "hotel_revenue", "col_a": "revenue", "type_a": "int", "col_b": "cost", "type_b": "double"},
10373+
{"model_name": "coffee_sales", "col_a": "sales", "type_a": "bigint", "col_b": "profit", "type_b": "text"},
10374+
],
10375+
kind="FULL",
10376+
columns={
10377+
"@{col_a}": "@{type_a}",
10378+
"@{col_b}": "@{type_b}",
10379+
},
10380+
)
10381+
def entrypoint(context, *args, **kwargs):
10382+
return pd.DataFrame({
10383+
context.blueprint_var("col_a"): [1],
10384+
context.blueprint_var("col_b"): [1.5],
10385+
})
10386+
"""
10387+
)
10388+
10389+
ctx = Context(
10390+
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
10391+
paths=tmp_path,
10392+
)
10393+
assert len(ctx.models) == 2
10394+
10395+
model1 = ctx.get_model("test_schema.hotel_revenue", raise_if_missing=True)
10396+
model2 = ctx.get_model("test_schema.coffee_sales", raise_if_missing=True)
10397+
10398+
assert model1.columns_to_types_ is not None
10399+
assert set(model1.columns_to_types_.keys()) == {"revenue", "cost"}
10400+
assert model1.columns_to_types_["revenue"] == exp.DataType.build("int")
10401+
assert model1.columns_to_types_["cost"] == exp.DataType.build("double")
10402+
10403+
assert model2.columns_to_types_ is not None
10404+
assert set(model2.columns_to_types_.keys()) == {"sales", "profit"}
10405+
assert model2.columns_to_types_["sales"] == exp.DataType.build("bigint")
10406+
assert model2.columns_to_types_["profit"] == exp.DataType.build("text")
10407+
10408+
10409+
def test_python_model_variable_column_names(tmp_path: Path) -> None:
10410+
"""Global variables can be used as column names in Python model definitions."""
10411+
py_model = tmp_path / "models" / "var_col_names.py"
10412+
py_model.parent.mkdir(parents=True, exist_ok=True)
10413+
py_model.write_text(
10414+
"""
10415+
import pandas as pd # noqa: TID253
10416+
from sqlmesh import model
10417+
10418+
@model(
10419+
"test_schema.model",
10420+
kind="FULL",
10421+
columns={
10422+
"@{metric_col}": "int",
10423+
"static_col": "text",
10424+
},
10425+
)
10426+
def entrypoint(context, *args, **kwargs):
10427+
return pd.DataFrame({"revenue": [1], "static_col": ["x"]})
10428+
"""
10429+
)
10430+
10431+
ctx = Context(
10432+
config=Config(
10433+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
10434+
variables={"metric_col": "revenue"},
10435+
),
10436+
paths=tmp_path,
10437+
)
10438+
assert len(ctx.models) == 1
10439+
10440+
model = ctx.get_model("test_schema.model", raise_if_missing=True)
10441+
10442+
assert model.columns_to_types_ is not None
10443+
assert set(model.columns_to_types_.keys()) == {"revenue", "static_col"}
10444+
assert model.columns_to_types_["revenue"] == exp.DataType.build("int")
10445+
assert model.columns_to_types_["static_col"] == exp.DataType.build("text")
10446+
10447+
1035610448
@time_machine.travel("2020-01-01 00:00:00 UTC")
1035710449
def test_dynamic_date_spine_model(assert_exp_eq):
1035810450
@macro()

tests/core/test_plan_stages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,8 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
692692
finalized_ts=to_timestamp("2023-01-02"),
693693
)
694694

695-
state_reader.get_environment.side_effect = (
696-
lambda name: existing_dev_environment if name == "dev" else existing_prod_environment
695+
state_reader.get_environment.side_effect = lambda name: (
696+
existing_dev_environment if name == "dev" else existing_prod_environment
697697
)
698698
state_reader.get_environments_summary.return_value = [
699699
existing_prod_environment.summary,
@@ -857,8 +857,8 @@ def test_build_plan_stages_restatement_dev_does_not_clear_intervals(
857857
finalized_ts=to_timestamp("2023-01-02"),
858858
)
859859

860-
state_reader.get_environment.side_effect = (
861-
lambda name: existing_dev_environment if name == "dev" else existing_prod_environment
860+
state_reader.get_environment.side_effect = lambda name: (
861+
existing_dev_environment if name == "dev" else existing_prod_environment
862862
)
863863
state_reader.get_environments_summary.return_value = [
864864
existing_prod_environment.summary,

tests/core/test_selector_native.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot)
231231
)
232232

233233
state_reader_mock = mocker.Mock()
234-
state_reader_mock.get_environment.side_effect = (
235-
lambda name: prod_env if name == "prod" else dev_env
234+
state_reader_mock.get_environment.side_effect = lambda name: (
235+
prod_env if name == "prod" else dev_env
236236
)
237237

238238
all_snapshots = {
@@ -875,8 +875,8 @@ def test_select_models_selected_fqns_fallback(mocker: MockerFixture, make_snapsh
875875
)
876876

877877
state_reader_mock = mocker.Mock()
878-
state_reader_mock.get_environment.side_effect = (
879-
lambda name: fallback_env if name == "prod" else None
878+
state_reader_mock.get_environment.side_effect = lambda name: (
879+
fallback_env if name == "prod" else None
880880
)
881881
state_reader_mock.get_snapshots.return_value = {
882882
deleted_model_snapshot.snapshot_id: deleted_model_snapshot,

0 commit comments

Comments
 (0)