From 0760b313bc431ef8fcaaea8cbc063b6b1e1c26eb Mon Sep 17 00:00:00 2001 From: lkstrp Date: Fri, 8 May 2026 11:31:55 +0200 Subject: [PATCH 1/6] chore: use pydantic as schema provider --- config/config.schema.json | 83 ++++++++++++ config/config.yaml | 14 +- pixi.toml | 4 +- tests/integration_test.py | 2 +- workflow/Snakefile | 10 +- workflow/internal/config.schema.yaml | 56 -------- workflow/scripts/_schema.py | 194 +++++++++++++++++++++++++++ 7 files changed, 296 insertions(+), 67 deletions(-) create mode 100644 config/config.schema.json delete mode 100644 workflow/internal/config.schema.yaml create mode 100644 workflow/scripts/_schema.py diff --git a/config/config.schema.json b/config/config.schema.json new file mode 100644 index 0000000..58e5604 --- /dev/null +++ b/config/config.schema.json @@ -0,0 +1,83 @@ +{ + "additionalProperties": false, + "properties": { + "countries": { + "default": [ + "benin", + "togo" + ], + "description": "tbd", + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + "retrieve": { + "additionalProperties": false, + "properties": { + "source": { + "default": "geofabrik", + "description": "tbd", + "enum": [ + "geofabrik", + "overpass" + ], + "type": "string" + }, + "primary_name": { + "default": "power", + "description": "tbd", + "type": "string" + }, + "features": { + "default": [ + "substation", + "line" + ], + "description": "tbd", + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + "force_redownload": { + "default": false, + "description": "tbd", + "type": "boolean" + }, + "mp": { + "default": true, + "description": "tbd", + "type": "boolean" + }, + "stream_backend": { + "default": true, + "description": "tbd", + "type": "boolean" + }, + "cache_primary": { + "default": false, + "description": "tbd", + "type": "boolean" + }, + "target_date": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "tbd" + } + }, + "description": "tbd" 
+ } + }, + "title": "ConfigSchema", + "type": "object" +} diff --git a/config/config.yaml b/config/config.yaml index dd19588..b7b71cd 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,15 +1,17 @@ +%YAML 1.1 +--- +# yaml-language-server: $schema=./config.schema.json countries: - - benin - - togo - +- benin +- togo retrieve: source: geofabrik primary_name: power features: - - substation - - line + - substation + - line force_redownload: false mp: true stream_backend: true cache_primary: false - target_date: null + target_date: diff --git a/pixi.toml b/pixi.toml index 0c66cc6..682d9d6 100644 --- a/pixi.toml +++ b/pixi.toml @@ -13,7 +13,8 @@ clio-tools = ">=2026.03.30" conda = ">=25.0.0" ipdb = ">=0.13.13" ipykernel = ">=6.29.5" -jsonschema = ">=4.0.0" +pydantic = ">=2.0" +"ruamel.yaml" = ">=0.18" mypy = ">=1.15.0" pytest = ">=8.3.5" python = ">=3.12" @@ -25,3 +26,4 @@ earth-osm = ">=3.0.2" [tasks] test-integration = {cmd = "pytest tests/integration_test.py"} +generate-config = {cmd = "python workflow/scripts/_schema.py"} diff --git a/tests/integration_test.py b/tests/integration_test.py index 009e87f..8903cfa 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -31,7 +31,7 @@ def test_interface_file(module_path): "LICENSE", "README.md", "config/config.yaml", - "workflow/internal/config.schema.yaml", + "config/config.schema.json", "tests/integration/Snakefile", ], ) diff --git a/workflow/Snakefile b/workflow/Snakefile index fd9de55..9266a38 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -1,6 +1,10 @@ +import sys + import yaml +from snakemake.utils import min_version -from snakemake.utils import min_version, validate +sys.path.insert(0, workflow.basedir) +from scripts._schema import validate_config min_version("9.19") @@ -15,8 +19,8 @@ pathvars: configfile: workflow.source_path("../config/config.yaml") -# Validate the configuration using the schema file. 
-validate(config, workflow.source_path("internal/config.schema.yaml")) +# Validate the configuration using the Pydantic schema. +config = validate_config(config) # Load internal settings separately so users cannot modify them. with open(workflow.source_path("internal/settings.yaml"), "r") as f: diff --git a/workflow/internal/config.schema.yaml b/workflow/internal/config.schema.yaml deleted file mode 100644 index ab6d383..0000000 --- a/workflow/internal/config.schema.yaml +++ /dev/null @@ -1,56 +0,0 @@ -$schema: "https://json-schema.org/draft/2020-12/schema" -description: "Schema for user-provided configuration files." -type: object -additionalProperties: false -properties: - countries: - description: Country or region slugs used across the workflow. - type: array - minItems: 1 - items: - type: string - retrieve: - type: object - additionalProperties: false - required: - - source - - primary_name - - features - - force_redownload - - mp - - stream_backend - - cache_primary - - target_date - properties: - source: - description: Retrieval backend for OSM data. - type: string - enum: - - geofabrik - - overpass - features: - description: OSM features to retrieve for each country. - type: array - minItems: 1 - items: - type: string - primary_name: - description: Primary OSM theme to retrieve. - type: string - force_redownload: - description: Force refresh of cached inputs inside earth_osm. - type: boolean - mp: - description: Enable multiprocessing in earth_osm. - type: boolean - stream_backend: - description: Use the streaming pipeline when available. - type: boolean - cache_primary: - description: Cache the primary feature during retrieval. - type: boolean - target_date: - description: Optional historical target date. - type: - - string - - "null" diff --git a/workflow/scripts/_schema.py b/workflow/scripts/_schema.py new file mode 100644 index 0000000..e7f1283 --- /dev/null +++ b/workflow/scripts/_schema.py @@ -0,0 +1,194 @@ +"""Config validation for grid-builder. 
+ +Pydantic models are the single source of truth for config structure, +defaults, schema, and validation. +""" + +import json +import math +import re +from collections.abc import Iterator +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap + + +class ConfigModel(BaseModel): + """Base model with dict-like access for Snakemake compatibility.""" + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def keys(self) -> Iterator[str]: + return iter(self.model_fields.keys()) + + def values(self) -> Iterator[Any]: + return (getattr(self, k) for k in self.model_fields.keys()) + + def items(self) -> Iterator[tuple[str, Any]]: + return ((k, getattr(self, k)) for k in self.model_fields.keys()) + + +class RetrieveConfig(ConfigModel): + model_config = ConfigDict(extra="forbid") + + source: Literal["geofabrik", "overpass"] = Field( + "geofabrik", description="tbd" + ) + primary_name: str = Field("power", description="tbd") + features: list[str] = Field( + default=["substation", "line"], + description="tbd", + min_length=1, + ) + force_redownload: bool = Field(False, description="tbd") + mp: bool = Field(True, description="tbd") + stream_backend: bool = Field(True, description="tbd") + cache_primary: bool = Field(False, description="tbd") + target_date: str | None = Field(None, description="tbd") + + +class ConfigSchema(ConfigModel): + model_config = ConfigDict(extra="forbid") + + countries: list[str] = Field( + default=["benin", "togo"], + description="tbd", + min_length=1, + ) + retrieve: RetrieveConfig = Field( + default_factory=RetrieveConfig, + description="tbd", + ) + + +def validate_config(config: dict) -> ConfigSchema: + """Validate config dict against 
schema.""" + return ConfigSchema(**config) + + +def generate_config_defaults(path: str = "config/config.yaml") -> dict: + """Generate config defaults YAML file and return the defaults dict.""" + config = validate_config({}) + defaults = config.model_dump() + + yaml_writer = YAML() + yaml_writer.version = (1, 1) + yaml_writer.default_flow_style = False + yaml_writer.width = 4096 + yaml_writer.indent(mapping=2, sequence=2, offset=0) + + def str_representer(dumper, data): + TAG = "tag:yaml.org,2002:str" + if "\n" in data: + return dumper.represent_scalar(TAG, data, style="|") + if data == "" or any(c in data for c in ":{}[]&*#?|-<>=!%@"): + return dumper.represent_scalar(TAG, data, style='"') + return dumper.represent_scalar(TAG, data, style="") + + yaml_writer.representer.add_representer(str, str_representer) + + data = CommentedMap() + data.yaml_set_start_comment( + "yaml-language-server: $schema=./config.schema.json" + ) + + for key, value in defaults.items(): + data[key] = value + + with open(path, "w") as f: + yaml_writer.dump(data, f) + + return defaults + + +def generate_config_schema(path: str = "config/config.schema.json") -> dict: + """Generate JSON schema file and return the schema dict.""" + + def resolve_refs(obj, defs): + if isinstance(obj, dict): + if "$ref" in obj: + ref_name = obj["$ref"].split("/")[-1] + if ref_name in defs: + resolved = resolve_refs(defs[ref_name].copy(), defs) + if "description" in obj and "description" not in resolved: + resolved["description"] = obj["description"] + return resolved + return {k: resolve_refs(v, defs) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve_refs(item, defs) for item in obj] + return obj + + def sanitize_for_json(obj): + if isinstance(obj, dict): + return {k: sanitize_for_json(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [sanitize_for_json(v) for v in obj] + elif isinstance(obj, float) and math.isinf(obj): + return None + return obj + + def 
remove_nested_titles(obj, is_root=True): + if isinstance(obj, dict): + result = {} + for k, v in obj.items(): + if k == "title" and not is_root: + continue + result[k] = remove_nested_titles(v, is_root=False) + return result + elif isinstance(obj, list): + return [remove_nested_titles(item, is_root=False) for item in obj] + return obj + + def remove_object_type(obj, is_root=True): + if isinstance(obj, dict): + result = {} + for k, v in obj.items(): + if ( + k == "type" + and v == "object" + and not is_root + and "properties" in obj + ): + continue + result[k] = remove_object_type(v, is_root=False) + return result + elif isinstance(obj, list): + return [remove_object_type(item, is_root=False) for item in obj] + return obj + + config = validate_config({}) + schema = config.model_json_schema() + defs = schema.pop("$defs", {}) + schema = resolve_refs(schema, defs) + schema = sanitize_for_json(schema) + schema = remove_nested_titles(schema) + schema = remove_object_type(schema) + + with open(path, "w") as f: + json.dump(schema, f, indent=2) + f.write("\n") + + return schema + + +__all__ = [ + "ConfigSchema", + "validate_config", + "generate_config_defaults", + "generate_config_schema", + "ValidationError", +] + + +if __name__ == "__main__": + generate_config_defaults() + generate_config_schema() From 2bc74ee2cd238dbfbf38ee501dd821156fea1f9d Mon Sep 17 00:00:00 2001 From: Bobby Xiong Date: Tue, 12 May 2026 13:43:10 +0200 Subject: [PATCH 2/6] fix: pydantic model_fields deprecation warning by accessing model_fields on the class instead of instance --- workflow/scripts/_schema.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/workflow/scripts/_schema.py b/workflow/scripts/_schema.py index e7f1283..0558988 100644 --- a/workflow/scripts/_schema.py +++ b/workflow/scripts/_schema.py @@ -28,13 +28,13 @@ def get(self, key: str, default: Any = None) -> Any: return getattr(self, key, default) def keys(self) -> Iterator[str]: - return 
iter(self.model_fields.keys()) + return iter(type(self).model_fields.keys()) def values(self) -> Iterator[Any]: - return (getattr(self, k) for k in self.model_fields.keys()) + return (getattr(self, k) for k in type(self).model_fields.keys()) def items(self) -> Iterator[tuple[str, Any]]: - return ((k, getattr(self, k)) for k in self.model_fields.keys()) + return ((k, getattr(self, k)) for k in type(self).model_fields.keys()) class RetrieveConfig(ConfigModel): From b16501dad4dee3055fd2c035c1b32d0194874e3a Mon Sep 17 00:00:00 2001 From: Bobby Xiong Date: Tue, 12 May 2026 14:32:59 +0200 Subject: [PATCH 3/6] Add schema descriptions, add config.countries validator. Pre-commit fixes. --- config/config.schema.json | 24 +++++++------- config/config.yaml | 3 +- workflow/rules/retrieve.smk | 3 +- workflow/scripts/_helpers.py | 7 ++-- workflow/scripts/_schema.py | 56 +++++++++++++++++++++++--------- workflow/scripts/retrieve_osm.py | 37 +++++++++++++++------ 6 files changed, 86 insertions(+), 44 deletions(-) diff --git a/config/config.schema.json b/config/config.schema.json index 58e5604..7c54d04 100644 --- a/config/config.schema.json +++ b/config/config.schema.json @@ -3,10 +3,9 @@ "properties": { "countries": { "default": [ - "benin", - "togo" + "BE" ], - "description": "tbd", + "description": "List of countries to retrieve OSM data for", "items": { "type": "string" }, @@ -18,7 +17,7 @@ "properties": { "source": { "default": "geofabrik", - "description": "tbd", + "description": "Retrieval backend for OSM data", "enum": [ "geofabrik", "overpass" @@ -27,7 +26,7 @@ }, "primary_name": { "default": "power", - "description": "tbd", + "description": "Primary OSM feature to retrieve (e.g., 'power')", "type": "string" }, "features": { @@ -35,7 +34,7 @@ "substation", "line" ], - "description": "tbd", + "description": "OSM features to retrieve for each country", "items": { "type": "string" }, @@ -44,27 +43,28 @@ }, "force_redownload": { "default": false, - "description": "tbd", + 
"description": "Force refresh of cached data in earth-osm", "type": "boolean" }, "mp": { "default": true, - "description": "tbd", + "description": "Enable multiprocessing in earth-osm", "type": "boolean" }, "stream_backend": { "default": true, - "description": "tbd", + "description": "Enable streaming backend in earth-osm", "type": "boolean" }, "cache_primary": { "default": false, - "description": "tbd", + "description": "Enable caching of primary feature data in earth-osm", "type": "boolean" }, "target_date": { "anyOf": [ { + "format": "date-time", "type": "string" }, { @@ -72,10 +72,10 @@ } ], "default": null, - "description": "tbd" + "description": "Optional historical date for data retrieval in ISO 8601 datetime format" } }, - "description": "tbd" + "description": "Configuration for OSM data retrieval using earth-osm" } }, "title": "ConfigSchema", diff --git a/config/config.yaml b/config/config.yaml index b7b71cd..962b694 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -2,8 +2,7 @@ --- # yaml-language-server: $schema=./config.schema.json countries: -- benin -- togo +- BE retrieve: source: geofabrik primary_name: power diff --git a/workflow/rules/retrieve.smk b/workflow/rules/retrieve.smk index 10bf498..63b2059 100644 --- a/workflow/rules/retrieve.smk +++ b/workflow/rules/retrieve.smk @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: MIT + rule retrieve_osm: output: csv=expand( @@ -44,4 +45,4 @@ rule retrieve_osm_all: "/automatic/osm/out/{country}_{feature}.geojson", country=config["countries"], feature=config["retrieve"]["features"], - ) \ No newline at end of file + ), diff --git a/workflow/scripts/_helpers.py b/workflow/scripts/_helpers.py index 772cf48..1b1db79 100644 --- a/workflow/scripts/_helpers.py +++ b/workflow/scripts/_helpers.py @@ -17,8 +17,9 @@ def mock_snakemake( submodule_dir="workflow/submodules/pypsa-eur", **wildcards, ): - """ - This function is expected to be executed from the 'scripts'-directory of ' + """Mock a Snakemake object for 
testing scripts outside of Snakemake. + + This function is expected to be executed from the 'scripts'-directory of the snakemake project. It returns a snakemake.script.Snakemake object, based on the Snakefile. @@ -155,4 +156,4 @@ def make_accessable(*ios): finally: if user_in_script_dir: os.chdir(script_dir) - return snakemake \ No newline at end of file + return snakemake diff --git a/workflow/scripts/_schema.py b/workflow/scripts/_schema.py index 0558988..d68b75f 100644 --- a/workflow/scripts/_schema.py +++ b/workflow/scripts/_schema.py @@ -6,14 +6,17 @@ import json import math -import re from collections.abc import Iterator +from datetime import datetime from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, ValidationError +from earth_osm.regions import get_all_valid_codes +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap +_VALID_REGIONS: frozenset[str] = frozenset(get_all_valid_codes()) + class ConfigModel(BaseModel): """Base model with dict-like access for Snakemake compatibility.""" @@ -41,32 +44,55 @@ class RetrieveConfig(ConfigModel): model_config = ConfigDict(extra="forbid") source: Literal["geofabrik", "overpass"] = Field( - "geofabrik", description="tbd" + "geofabrik", description="Retrieval backend for OSM data" + ) + primary_name: str = Field( + "power", description="Primary OSM feature to retrieve (e.g., 'power')" ) - primary_name: str = Field("power", description="tbd") features: list[str] = Field( default=["substation", "line"], - description="tbd", + description="OSM features to retrieve for each country", min_length=1, ) - force_redownload: bool = Field(False, description="tbd") - mp: bool = Field(True, description="tbd") - stream_backend: bool = Field(True, description="tbd") - cache_primary: bool = Field(False, description="tbd") - target_date: str | None = Field(None, description="tbd") + force_redownload: 
bool = Field( + False, description="Force refresh of cached data in earth-osm" + ) + mp: bool = Field(True, description="Enable multiprocessing in earth-osm") + stream_backend: bool = Field( + True, description="Enable streaming backend in earth-osm" + ) + cache_primary: bool = Field( + False, description="Enable caching of primary feature data in earth-osm" + ) + target_date: datetime | None = Field( + None, + description="Optional historical date for data retrieval in ISO 8601 datetime format", + ) class ConfigSchema(ConfigModel): model_config = ConfigDict(extra="forbid") countries: list[str] = Field( - default=["benin", "togo"], - description="tbd", + default=["BE"], + description="List of countries to retrieve OSM data for", min_length=1, ) + + @field_validator("countries") + @classmethod + def validate_country_identifiers(cls, v: list[str]) -> list[str]: + invalid = [c for c in v if c not in _VALID_REGIONS] + if invalid: + raise ValueError( + f"Unknown country identifier(s): {invalid}. " + "Use an English name (e.g. 'benin') or ISO 3166-1 alpha-2 code (e.g. 'BE')." 
+ ) + return v + retrieve: RetrieveConfig = Field( default_factory=RetrieveConfig, - description="tbd", + description="Configuration for OSM data retrieval using earth-osm", ) @@ -97,9 +123,7 @@ def str_representer(dumper, data): yaml_writer.representer.add_representer(str, str_representer) data = CommentedMap() - data.yaml_set_start_comment( - "yaml-language-server: $schema=./config.schema.json" - ) + data.yaml_set_start_comment("yaml-language-server: $schema=./config.schema.json") for key, value in defaults.items(): data[key] = value diff --git a/workflow/scripts/retrieve_osm.py b/workflow/scripts/retrieve_osm.py index 476b732..830668d 100644 --- a/workflow/scripts/retrieve_osm.py +++ b/workflow/scripts/retrieve_osm.py @@ -6,6 +6,7 @@ import logging import os +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -31,10 +32,9 @@ def retrieve_osm_data( mp: bool = True, stream_backend: bool = True, cache_primary: bool = False, - target_date: str | None = None, + target_date: datetime | None = None, ) -> None: - """ - Retrieve OSM data for a single country using earth_osm. + """Retrieve OSM data for a single country using earth_osm. Parameters ---------- @@ -56,7 +56,7 @@ def retrieve_osm_data( Enable streaming backend. Default is True. cache_primary : bool, optional Cache primary data. Default is False. - target_date : str | None, optional + target_date : datetime.datetime | None, optional Target date for historical data. Default is None. """ logger.info(f"Retrieving OSM data for {country} with features: {features}") @@ -82,20 +82,37 @@ def retrieve_osm_data( logger.info(f"Successfully retrieved OSM data for {country}") +def parse_target_date(target_date: datetime | str | None) -> datetime | None: + """Convert a YAML/snaked config value into the datetime earth_osm expects. + + Parameters + ---------- + target_date : datetime.datetime | str | None + The target date as a datetime object, ISO 8601 string, or None. 
+ + Returns: + ------- + datetime.datetime | None + The target date as a datetime object, or None if not provided. + """ + if target_date is None or isinstance(target_date, datetime): + return target_date + return datetime.fromisoformat(target_date) + + if __name__ == "__main__": if "snakemake" not in globals(): from workflow.scripts._helpers import mock_snakemake - snakemake = mock_snakemake( - "retrieve_osm", - country="benin", - ) + snakemake = mock_snakemake("retrieve_osm", country="benin") # Extract parameters country = snakemake.wildcards.country features = list(snakemake.params.features) base_dir = str(Path(snakemake.output.geojson[0]).parent.parent) + target_date = parse_target_date(snakemake.params.target_date) + # Call main function retrieve_osm_data( country=country, @@ -107,5 +124,5 @@ def retrieve_osm_data( mp=snakemake.params.mp, stream_backend=snakemake.params.stream_backend, cache_primary=snakemake.params.cache_primary, - target_date=snakemake.params.target_date, - ) \ No newline at end of file + target_date=target_date, + ) From b9291a6059dd26d44dd0138bd5b9ab8f7f662a98 Mon Sep 17 00:00:00 2001 From: Bobby Xiong Date: Tue, 12 May 2026 14:40:18 +0200 Subject: [PATCH 4/6] Update and clean up Snakefile. --- workflow/Snakefile | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/workflow/Snakefile b/workflow/Snakefile index 9266a38..7417149 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -15,11 +15,9 @@ pathvars: osm_out="/automatic/osm/out", -# Load the example configuration. This will be overridden by users. +# Default configuration file generated from pydantic schema. configfile: workflow.source_path("../config/config.yaml") - -# Validate the configuration using the Pydantic schema. config = validate_config(config) # Load internal settings separately so users cannot modify them. 
@@ -27,25 +25,10 @@ with open(workflow.source_path("internal/settings.yaml"), "r") as f: internal = yaml.safe_load(f) -# Add all your includes here. include: "rules/retrieve.smk" -# Add additional files to be delivered alongside the workflow here. -# This is needed e.g. for python code files that are never -# explicitly used as a script or input in a snakemake rule. -# e.g.: workflow.source_path("scripts/_utils.py") - - rule all: default_target: True - output: - "INVALID", - log: - stderr="/all.stderr", - conda: - "envs/shell.yaml" - message: - "ERROR: Invalid `rule all:` call" - shell: - 'echo "This workflow must be called as a snakemake module." > {log.stderr}' + input: + rules.retrieve_osm_all.input, From 1a38a7af3d1076f05bbf1a945c85b6e2b07cba6d Mon Sep 17 00:00:00 2001 From: Bobby Xiong Date: Tue, 12 May 2026 15:11:42 +0200 Subject: [PATCH 5/6] Update and clean up Snakefiles and integration_test. --- tests/integration/Snakefile | 23 +++-------------------- tests/integration/test_config.yaml | 4 ++-- tests/integration_test.py | 8 -------- workflow/Snakefile | 2 +- workflow/rules/retrieve.smk | 8 ++++---- 5 files changed, 10 insertions(+), 35 deletions(-) diff --git a/tests/integration/Snakefile b/tests/integration/Snakefile index becbdaf..1a26443 100644 --- a/tests/integration/Snakefile +++ b/tests/integration/Snakefile @@ -2,37 +2,20 @@ configfile: workflow.source_path("./test_config.yaml") -# Import the module and configure it. -# `snakefile:` specifies the module. It can use file paths and special github(...) / gitlab(...) markers -# `config`: specifies the module configuration. -# `pathvars:` helps you re-wire where the module places files. 
module grid_builder: pathvars: - # Redirect OSM outputs - osm_out="resources/module/resources/automatic/osm/out", - # Redirect intermediate files that are internal to the module - logs="resources/module/logs", - resources="resources/module/resources", - results="resources/module/results", + resources="resources/grid-builder", + logs="logs/grid-builder", snakefile: "../../workflow/Snakefile" config: config["grid_builder"] -# rename all module rules with a prefix, to avoid naming conflicts. use rule * from grid_builder as grid_builder_* -# Request OSM retrieval outputs from the module rule all: default_target: True input: - expand( - "resources/module/resources/automatic/osm/out/{country}_{feature}.{ext}", - country="belgium", - feature=["substation", "line"], - ext=["csv", "geojson"], - ), - message: - "Retrieve OSM grid data by country and feature." + rules.grid_builder_retrieve_osm_all.input, diff --git a/tests/integration/test_config.yaml b/tests/integration/test_config.yaml index 8d47750..c20aecb 100644 --- a/tests/integration/test_config.yaml +++ b/tests/integration/test_config.yaml @@ -1,6 +1,6 @@ grid_builder: countries: - - belgium + - benin retrieve: source: geofabrik @@ -12,4 +12,4 @@ grid_builder: mp: true stream_backend: true cache_primary: false - target_date: null + target_date: diff --git a/tests/integration_test.py b/tests/integration_test.py index 8903cfa..e03e8af 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -40,14 +40,6 @@ def test_standard_file_existance(module_path, file): assert Path(module_path / file).exists() -def test_snakemake_all_failure(module_path): - """The snakemake 'all' rule should return an error by default.""" - process = subprocess.run( - "snakemake --cores 1", shell=True, cwd=module_path, capture_output=True - ) - assert "INVALID (missing locally)" in str(process.stderr) - - def test_snakemake_integration_testing(module_path): """Run a light-weight test simulating someone using this module.""" assert 
subprocess.run( diff --git a/workflow/Snakefile b/workflow/Snakefile index 7417149..476ea14 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -12,7 +12,7 @@ min_version("9.19") # Define pathvars to expose OSM retrieval outputs for downstream use. pathvars: # OSM retrieval outputs by country and feature - osm_out="/automatic/osm/out", + osm_out="/osm/out", # Default configuration file generated from pydantic schema. diff --git a/workflow/rules/retrieve.smk b/workflow/rules/retrieve.smk index 63b2059..8eaa932 100644 --- a/workflow/rules/retrieve.smk +++ b/workflow/rules/retrieve.smk @@ -6,12 +6,12 @@ rule retrieve_osm: output: csv=expand( - "/automatic/osm/out/{country}_{feature}.csv", + "/osm/out/{country}_{feature}.csv", country="{country}", feature=config["retrieve"]["features"], ), geojson=expand( - "/automatic/osm/out/{country}_{feature}.geojson", + "/osm/out/{country}_{feature}.geojson", country="{country}", feature=config["retrieve"]["features"], ), @@ -37,12 +37,12 @@ rule retrieve_osm: rule retrieve_osm_all: input: csv=expand( - "/automatic/osm/out/{country}_{feature}.csv", + "/osm/out/{country}_{feature}.csv", country=config["countries"], feature=config["retrieve"]["features"], ), geojson=expand( - "/automatic/osm/out/{country}_{feature}.geojson", + "/osm/out/{country}_{feature}.geojson", country=config["countries"], feature=config["retrieve"]["features"], ), From 4437685a6ea8e32092fd63d1b2d875ebee397f1f Mon Sep 17 00:00:00 2001 From: Bobby Xiong Date: Tue, 12 May 2026 15:12:46 +0200 Subject: [PATCH 6/6] Snakefile formatting. --- workflow/Snakefile | 1 + 1 file changed, 1 insertion(+) diff --git a/workflow/Snakefile b/workflow/Snakefile index 476ea14..af82649 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -18,6 +18,7 @@ pathvars: # Default configuration file generated from pydantic schema. 
configfile: workflow.source_path("../config/config.yaml") + config = validate_config(config) # Load internal settings separately so users cannot modify them.