diff --git a/config/config.schema.json b/config/config.schema.json new file mode 100644 index 0000000..7c54d04 --- /dev/null +++ b/config/config.schema.json @@ -0,0 +1,83 @@ +{ + "additionalProperties": false, + "properties": { + "countries": { + "default": [ + "BE" + ], + "description": "List of countries to retrieve OSM data for", + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + "retrieve": { + "additionalProperties": false, + "properties": { + "source": { + "default": "geofabrik", + "description": "Retrieval backend for OSM data", + "enum": [ + "geofabrik", + "overpass" + ], + "type": "string" + }, + "primary_name": { + "default": "power", + "description": "Primary OSM feature to retrieve (e.g., 'power')", + "type": "string" + }, + "features": { + "default": [ + "substation", + "line" + ], + "description": "OSM features to retrieve for each country", + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + "force_redownload": { + "default": false, + "description": "Force refresh of cached data in earth-osm", + "type": "boolean" + }, + "mp": { + "default": true, + "description": "Enable multiprocessing in earth-osm", + "type": "boolean" + }, + "stream_backend": { + "default": true, + "description": "Enable streaming backend in earth-osm", + "type": "boolean" + }, + "cache_primary": { + "default": false, + "description": "Enable caching of primary feature data in earth-osm", + "type": "boolean" + }, + "target_date": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional historical date for data retrieval in ISO 8601 datetime format" + } + }, + "description": "Configuration for OSM data retrieval using earth-osm" + } + }, + "title": "ConfigSchema", + "type": "object" +} diff --git a/config/config.yaml b/config/config.yaml index dd19588..962b694 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,15 +1,16 @@ +%YAML 
1.1 +--- +# yaml-language-server: $schema=./config.schema.json countries: - - benin - - togo - +- BE retrieve: source: geofabrik primary_name: power features: - - substation - - line + - substation + - line force_redownload: false mp: true stream_backend: true cache_primary: false - target_date: null + target_date: diff --git a/pixi.toml b/pixi.toml index 0c66cc6..682d9d6 100644 --- a/pixi.toml +++ b/pixi.toml @@ -13,7 +13,8 @@ clio-tools = ">=2026.03.30" conda = ">=25.0.0" ipdb = ">=0.13.13" ipykernel = ">=6.29.5" -jsonschema = ">=4.0.0" +pydantic = ">=2.0" +"ruamel.yaml" = ">=0.18" mypy = ">=1.15.0" pytest = ">=8.3.5" python = ">=3.12" @@ -25,3 +26,4 @@ earth-osm = ">=3.0.2" [tasks] test-integration = {cmd = "pytest tests/integration_test.py"} +generate-config = {cmd = "python workflow/scripts/_schema.py"} diff --git a/tests/integration/Snakefile b/tests/integration/Snakefile index becbdaf..1a26443 100644 --- a/tests/integration/Snakefile +++ b/tests/integration/Snakefile @@ -2,37 +2,20 @@ configfile: workflow.source_path("./test_config.yaml") -# Import the module and configure it. -# `snakefile:` specifies the module. It can use file paths and special github(...) / gitlab(...) markers -# `config`: specifies the module configuration. -# `pathvars:` helps you re-wire where the module places files. module grid_builder: pathvars: - # Redirect OSM outputs - osm_out="resources/module/resources/automatic/osm/out", - # Redirect intermediate files that are internal to the module - logs="resources/module/logs", - resources="resources/module/resources", - results="resources/module/results", + resources="resources/grid-builder", + logs="logs/grid-builder", snakefile: "../../workflow/Snakefile" config: config["grid_builder"] -# rename all module rules with a prefix, to avoid naming conflicts. 
use rule * from grid_builder as grid_builder_* -# Request OSM retrieval outputs from the module rule all: default_target: True input: - expand( - "resources/module/resources/automatic/osm/out/{country}_{feature}.{ext}", - country="belgium", - feature=["substation", "line"], - ext=["csv", "geojson"], - ), - message: - "Retrieve OSM grid data by country and feature." + rules.grid_builder_retrieve_osm_all.input, diff --git a/tests/integration/test_config.yaml b/tests/integration/test_config.yaml index 8d47750..c20aecb 100644 --- a/tests/integration/test_config.yaml +++ b/tests/integration/test_config.yaml @@ -1,6 +1,6 @@ grid_builder: countries: - - belgium + - benin retrieve: source: geofabrik @@ -12,4 +12,4 @@ grid_builder: mp: true stream_backend: true cache_primary: false - target_date: null + target_date: diff --git a/tests/integration_test.py b/tests/integration_test.py index 009e87f..e03e8af 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -31,7 +31,7 @@ def test_interface_file(module_path): "LICENSE", "README.md", "config/config.yaml", - "workflow/internal/config.schema.yaml", + "config/config.schema.json", "tests/integration/Snakefile", ], ) @@ -40,14 +40,6 @@ def test_standard_file_existance(module_path, file): assert Path(module_path / file).exists() -def test_snakemake_all_failure(module_path): - """The snakemake 'all' rule should return an error by default.""" - process = subprocess.run( - "snakemake --cores 1", shell=True, cwd=module_path, capture_output=True - ) - assert "INVALID (missing locally)" in str(process.stderr) - - def test_snakemake_integration_testing(module_path): """Run a light-weight test simulating someone using this module.""" assert subprocess.run( diff --git a/workflow/Snakefile b/workflow/Snakefile index fd9de55..af82649 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -1,6 +1,10 @@ +import sys + import yaml +from snakemake.utils import min_version -from snakemake.utils import min_version, validate 
+sys.path.insert(0, workflow.basedir) +from scripts._schema import validate_config min_version("9.19") @@ -8,40 +12,24 @@ min_version("9.19") # Define pathvars to expose OSM retrieval outputs for downstream use. pathvars: # OSM retrieval outputs by country and feature - osm_out="/automatic/osm/out", + osm_out="/osm/out", -# Load the example configuration. This will be overridden by users. +# Default configuration file generated from pydantic schema. configfile: workflow.source_path("../config/config.yaml") -# Validate the configuration using the schema file. -validate(config, workflow.source_path("internal/config.schema.yaml")) +config = validate_config(config) # Load internal settings separately so users cannot modify them. with open(workflow.source_path("internal/settings.yaml"), "r") as f: internal = yaml.safe_load(f) -# Add all your includes here. include: "rules/retrieve.smk" -# Add additional files to be delivered alongside the workflow here. -# This is needed e.g. for python code files that are never -# explicitly used as a script or input in a snakemake rule. -# e.g.: workflow.source_path("scripts/_utils.py") - - rule all: default_target: True - output: - "INVALID", - log: - stderr="/all.stderr", - conda: - "envs/shell.yaml" - message: - "ERROR: Invalid `rule all:` call" - shell: - 'echo "This workflow must be called as a snakemake module." > {log.stderr}' + input: + rules.retrieve_osm_all.input, diff --git a/workflow/internal/config.schema.yaml b/workflow/internal/config.schema.yaml deleted file mode 100644 index ab6d383..0000000 --- a/workflow/internal/config.schema.yaml +++ /dev/null @@ -1,56 +0,0 @@ -$schema: "https://json-schema.org/draft/2020-12/schema" -description: "Schema for user-provided configuration files." -type: object -additionalProperties: false -properties: - countries: - description: Country or region slugs used across the workflow. 
- type: array - minItems: 1 - items: - type: string - retrieve: - type: object - additionalProperties: false - required: - - source - - primary_name - - features - - force_redownload - - mp - - stream_backend - - cache_primary - - target_date - properties: - source: - description: Retrieval backend for OSM data. - type: string - enum: - - geofabrik - - overpass - features: - description: OSM features to retrieve for each country. - type: array - minItems: 1 - items: - type: string - primary_name: - description: Primary OSM theme to retrieve. - type: string - force_redownload: - description: Force refresh of cached inputs inside earth_osm. - type: boolean - mp: - description: Enable multiprocessing in earth_osm. - type: boolean - stream_backend: - description: Use the streaming pipeline when available. - type: boolean - cache_primary: - description: Cache the primary feature during retrieval. - type: boolean - target_date: - description: Optional historical target date. - type: - - string - - "null" diff --git a/workflow/rules/retrieve.smk b/workflow/rules/retrieve.smk index 10bf498..8eaa932 100644 --- a/workflow/rules/retrieve.smk +++ b/workflow/rules/retrieve.smk @@ -2,15 +2,16 @@ # # SPDX-License-Identifier: MIT + rule retrieve_osm: output: csv=expand( - "/automatic/osm/out/{country}_{feature}.csv", + "/osm/out/{country}_{feature}.csv", country="{country}", feature=config["retrieve"]["features"], ), geojson=expand( - "/automatic/osm/out/{country}_{feature}.geojson", + "/osm/out/{country}_{feature}.geojson", country="{country}", feature=config["retrieve"]["features"], ), @@ -36,12 +37,12 @@ rule retrieve_osm: rule retrieve_osm_all: input: csv=expand( - "/automatic/osm/out/{country}_{feature}.csv", + "/osm/out/{country}_{feature}.csv", country=config["countries"], feature=config["retrieve"]["features"], ), geojson=expand( - "/automatic/osm/out/{country}_{feature}.geojson", + "/osm/out/{country}_{feature}.geojson", country=config["countries"], 
feature=config["retrieve"]["features"], - ) \ No newline at end of file + ), diff --git a/workflow/scripts/_helpers.py b/workflow/scripts/_helpers.py index 772cf48..1b1db79 100644 --- a/workflow/scripts/_helpers.py +++ b/workflow/scripts/_helpers.py @@ -17,8 +17,9 @@ def mock_snakemake( submodule_dir="workflow/submodules/pypsa-eur", **wildcards, ): - """ - This function is expected to be executed from the 'scripts'-directory of ' + """Mock a Snakemake object for testing scripts outside of Snakemake. + + This function is expected to be executed from the 'scripts'-directory of the snakemake project. It returns a snakemake.script.Snakemake object, based on the Snakefile. @@ -155,4 +156,4 @@ def make_accessable(*ios): finally: if user_in_script_dir: os.chdir(script_dir) - return snakemake \ No newline at end of file + return snakemake diff --git a/workflow/scripts/_schema.py b/workflow/scripts/_schema.py new file mode 100644 index 0000000..d68b75f --- /dev/null +++ b/workflow/scripts/_schema.py @@ -0,0 +1,218 @@ +"""Config validation for grid-builder. + +Pydantic models are the single source of truth for config structure, +defaults, schema, and validation. 
+""" + +import json +import math +from collections.abc import Iterator +from datetime import datetime +from typing import Any, Literal + +from earth_osm.regions import get_all_valid_codes +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap + +_VALID_REGIONS: frozenset[str] = frozenset(get_all_valid_codes()) + + +class ConfigModel(BaseModel): + """Base model with dict-like access for Snakemake compatibility.""" + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def keys(self) -> Iterator[str]: + return iter(type(self).model_fields.keys()) + + def values(self) -> Iterator[Any]: + return (getattr(self, k) for k in type(self).model_fields.keys()) + + def items(self) -> Iterator[tuple[str, Any]]: + return ((k, getattr(self, k)) for k in type(self).model_fields.keys()) + + +class RetrieveConfig(ConfigModel): + model_config = ConfigDict(extra="forbid") + + source: Literal["geofabrik", "overpass"] = Field( + "geofabrik", description="Retrieval backend for OSM data" + ) + primary_name: str = Field( + "power", description="Primary OSM feature to retrieve (e.g., 'power')" + ) + features: list[str] = Field( + default=["substation", "line"], + description="OSM features to retrieve for each country", + min_length=1, + ) + force_redownload: bool = Field( + False, description="Force refresh of cached data in earth-osm" + ) + mp: bool = Field(True, description="Enable multiprocessing in earth-osm") + stream_backend: bool = Field( + True, description="Enable streaming backend in earth-osm" + ) + cache_primary: bool = Field( + False, description="Enable caching of primary feature data in earth-osm" + ) + target_date: datetime | None = Field( + None, + description="Optional historical 
date for data retrieval in ISO 8601 datetime format", + ) + + +class ConfigSchema(ConfigModel): + model_config = ConfigDict(extra="forbid") + + countries: list[str] = Field( + default=["BE"], + description="List of countries to retrieve OSM data for", + min_length=1, + ) + + @field_validator("countries") + @classmethod + def validate_country_identifiers(cls, v: list[str]) -> list[str]: + invalid = [c for c in v if c not in _VALID_REGIONS] + if invalid: + raise ValueError( + f"Unknown country identifier(s): {invalid}. " + "Use an English name (e.g. 'benin') or ISO 3166-1 alpha-2 code (e.g. 'BE')." + ) + return v + + retrieve: RetrieveConfig = Field( + default_factory=RetrieveConfig, + description="Configuration for OSM data retrieval using earth-osm", + ) + + +def validate_config(config: dict) -> ConfigSchema: + """Validate config dict against schema.""" + return ConfigSchema(**config) + + +def generate_config_defaults(path: str = "config/config.yaml") -> dict: + """Generate config defaults YAML file and return the defaults dict.""" + config = validate_config({}) + defaults = config.model_dump() + + yaml_writer = YAML() + yaml_writer.version = (1, 1) + yaml_writer.default_flow_style = False + yaml_writer.width = 4096 + yaml_writer.indent(mapping=2, sequence=2, offset=0) + + def str_representer(dumper, data): + TAG = "tag:yaml.org,2002:str" + if "\n" in data: + return dumper.represent_scalar(TAG, data, style="|") + if data == "" or any(c in data for c in ":{}[]&*#?|-<>=!%@"): + return dumper.represent_scalar(TAG, data, style='"') + return dumper.represent_scalar(TAG, data, style="") + + yaml_writer.representer.add_representer(str, str_representer) + + data = CommentedMap() + data.yaml_set_start_comment("yaml-language-server: $schema=./config.schema.json") + + for key, value in defaults.items(): + data[key] = value + + with open(path, "w") as f: + yaml_writer.dump(data, f) + + return defaults + + +def generate_config_schema(path: str = "config/config.schema.json") -> 
dict: + """Generate JSON schema file and return the schema dict.""" + + def resolve_refs(obj, defs): + if isinstance(obj, dict): + if "$ref" in obj: + ref_name = obj["$ref"].split("/")[-1] + if ref_name in defs: + resolved = resolve_refs(defs[ref_name].copy(), defs) + if "description" in obj and "description" not in resolved: + resolved["description"] = obj["description"] + return resolved + return {k: resolve_refs(v, defs) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve_refs(item, defs) for item in obj] + return obj + + def sanitize_for_json(obj): + if isinstance(obj, dict): + return {k: sanitize_for_json(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [sanitize_for_json(v) for v in obj] + elif isinstance(obj, float) and math.isinf(obj): + return None + return obj + + def remove_nested_titles(obj, is_root=True): + if isinstance(obj, dict): + result = {} + for k, v in obj.items(): + if k == "title" and not is_root: + continue + result[k] = remove_nested_titles(v, is_root=False) + return result + elif isinstance(obj, list): + return [remove_nested_titles(item, is_root=False) for item in obj] + return obj + + def remove_object_type(obj, is_root=True): + if isinstance(obj, dict): + result = {} + for k, v in obj.items(): + if ( + k == "type" + and v == "object" + and not is_root + and "properties" in obj + ): + continue + result[k] = remove_object_type(v, is_root=False) + return result + elif isinstance(obj, list): + return [remove_object_type(item, is_root=False) for item in obj] + return obj + + config = validate_config({}) + schema = config.model_json_schema() + defs = schema.pop("$defs", {}) + schema = resolve_refs(schema, defs) + schema = sanitize_for_json(schema) + schema = remove_nested_titles(schema) + schema = remove_object_type(schema) + + with open(path, "w") as f: + json.dump(schema, f, indent=2) + f.write("\n") + + return schema + + +__all__ = [ + "ConfigSchema", + "validate_config", + "generate_config_defaults", 
"""Convert a YAML/Snakemake config value into the datetime earth_osm expects. + + Parameters + ---------- + target_date : datetime.datetime | str | None + The target date as a datetime object, ISO 8601 string, or None. + + Returns + ------- + datetime.datetime | None + The target date as a datetime object, or None if not provided.
+ """ + if target_date is None or isinstance(target_date, datetime): + return target_date + return datetime.fromisoformat(target_date) + + if __name__ == "__main__": if "snakemake" not in globals(): from workflow.scripts._helpers import mock_snakemake - snakemake = mock_snakemake( - "retrieve_osm", - country="benin", - ) + snakemake = mock_snakemake("retrieve_osm", country="benin") # Extract parameters country = snakemake.wildcards.country features = list(snakemake.params.features) base_dir = str(Path(snakemake.output.geojson[0]).parent.parent) + target_date = parse_target_date(snakemake.params.target_date) + # Call main function retrieve_osm_data( country=country, @@ -107,5 +124,5 @@ def retrieve_osm_data( mp=snakemake.params.mp, stream_backend=snakemake.params.stream_backend, cache_primary=snakemake.params.cache_primary, - target_date=snakemake.params.target_date, - ) \ No newline at end of file + target_date=target_date, + )