From 1685dec7b78f1b80f8248ed0fc251623cc51c5b7 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Tue, 21 Apr 2026 18:13:08 +0530 Subject: [PATCH 1/4] implement get /run/{id} --- src/config.py | 4 + src/config.toml | 3 + src/database/flows.py | 2 +- src/database/runs.py | 179 ++++++++++++ src/routers/openml/runs.py | 140 +++++++++- src/schemas/runs.py | 93 ++++++- tests/routers/openml/runs_get_test.py | 383 ++++++++++++++++++++++++++ 7 files changed, 799 insertions(+), 5 deletions(-) create mode 100644 tests/routers/openml/runs_get_test.py diff --git a/src/config.py b/src/config.py index ffc4fb89..011a5624 100644 --- a/src/config.py +++ b/src/config.py @@ -54,6 +54,10 @@ def load_routing_configuration(file: Path = _config_file) -> TomlTable: return typing.cast("TomlTable", _load_configuration(file)["routing"]) +def load_run_configuration(file: Path = _config_file) -> TomlTable: + return typing.cast("TomlTable", _load_configuration(file).get("run", {})) + + @functools.cache def load_database_configuration(file: Path = _config_file) -> TomlTable: configuration = _load_configuration(file) diff --git a/src/config.toml b/src/config.toml index 384067d7..b56a3e0e 100644 --- a/src/config.toml +++ b/src/config.toml @@ -37,3 +37,6 @@ database="openml" [routing] minio_url="http://minio:9000/" server_url="http://php-api:80/" + +[run] +evaluation_engine_ids = [1] diff --git a/src/database/flows.py b/src/database/flows.py index 79bb6e5b..ec04f6a6 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -73,7 +73,7 @@ async def get(id_: int, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( text( """ - SELECT *, uploadDate as upload_date + SELECT *, uploadDate as upload_date, fullName AS full_name FROM implementation WHERE id = :flow_id """, diff --git a/src/database/runs.py b/src/database/runs.py index acf7a532..5f7f2afb 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -22,6 +22,185 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool: return bool(row.one_or_none()) +async def get(run_id: int, expdb: AsyncConnection) -> Row | None: + """Fetch the core run row from the `run` table. + + Returns the row if found, or None if no run with `run_id` exists. + The `error_message` column is NULL when the run completed without errors. + """ + row = await expdb.execute( + text( + """ + SELECT `rid`, `uploader`, `setup`, `task_id`, `error_message` + FROM `run` + WHERE `rid` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return row.one_or_none() + + +async def get_uploader_name(uploader_id: int, userdb: AsyncConnection) -> str | None: + """Fetch the display name of a user from the openml database. + + Queries the `users` table in the separate openml DB and concatenates + first_name + ' ' + last_name. Returns None if the user does not exist. + """ + row = await userdb.execute( + text( + """ + SELECT CONCAT(`first_name`, ' ', `last_name`) AS `name` + FROM `users` + WHERE `id` = :uploader_id + """, + ), + parameters={"uploader_id": uploader_id}, + ) + result = row.one_or_none() + return result.name if result else None + + +async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]: + """Fetch all tags associated with a run from the `run_tag` table. + + The `id` column in `run_tag` refers to the run ID + """ + rows = await expdb.execute( + text( + """ + SELECT `tag` + FROM `run_tag` + WHERE `id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return [row.tag for row in rows.all()] + + +async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: + """Fetch the dataset(s) used as input for a run, with name and url. + + Joins `input_data` with `dataset` to include the dataset name and ARFF URL. + """ + rows = await expdb.execute( + text( + """ + SELECT `id`.`data` AS `did`, `d`.`name`, `d`.`url` + FROM `input_data` `id` + JOIN `dataset` `d` ON `id`.`data` = `d`.`did` + WHERE `id`.`run` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return cast("list[Row]", rows.all()) + + +async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: + """Fetch output files attached to a run from the `runfile` table. + + Typical entries include the description XML and predictions ARFF. + The `field` column holds the file label (e.g. "description", "predictions"). + + Note: the PHP response includes a deprecated `did` field hardcoded to "-1" + for each file. This implementation omits it entirely. + """ + rows = await expdb.execute( + text( + """ + SELECT `file_id`, `field` + FROM `runfile` + WHERE `source` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return cast("list[Row]", rows.all()) + + +async def get_evaluations( + run_id: int, + expdb: AsyncConnection, + *, + evaluation_engine_ids: list[int], +) -> list[Row]: + """Fetch evaluation metric results for a run. + + Joins `evaluation` with `math_function` to resolve the metric name + (the `evaluation` table stores only a `function_id`, not the name directly). + + Filters by `evaluation_engine_id IN (...)`. The list is configurable + via `config.toml [run] evaluation_engine_ids`. + Dynamic named parameters are used for aiomysql compatibility. + """ + if not evaluation_engine_ids: + return [] + + # Build :eid_0, :eid_1, ... placeholders — one per engine ID. + eid_params = {f"eid_{i}": eid for i, eid in enumerate(evaluation_engine_ids)} + placeholders = ", ".join(f":eid_{i}" for i in range(len(evaluation_engine_ids))) + + query = text( + f""" + SELECT `m`.`name`, `e`.`value`, `e`.`array_data` + FROM `evaluation` `e` + JOIN `math_function` `m` ON `e`.`function_id` = `m`.`id` + WHERE `e`.`source` = :run_id + AND `e`.`evaluation_engine_id` IN ({placeholders}) + """, # noqa: S608 # placeholders are trusted integer params, not user input + ) + rows = await expdb.execute( + query, + parameters={"run_id": run_id, **eid_params}, + ) + return cast("list[Row]", rows.all()) + + +async def get_task_type(task_id: int, expdb: AsyncConnection) -> str | None: + """Fetch the human-readable task type name for the task associated with a run. + + Joins `task` and `task_type` on `ttid` to resolve the name + (e.g. "Supervised Classification"). + """ + row = await expdb.execute( + text( + """ + SELECT `tt`.`name` + FROM `task` `t` + JOIN `task_type` `tt` ON `t`.`ttid` = `tt`.`ttid` + WHERE `t`.`task_id` = :task_id + """, + ), + parameters={"task_id": task_id}, + ) + result = row.one_or_none() + return result.name if result else None + + +async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> str | None: + """Fetch the evaluation measure configured for a task, if any. + + Queries `task_inputs` for the row where `input = 'evaluation_measures'`. + Returns None (not an empty string) when no such row exists, so callers + can treat a falsy result uniformly. + """ + row = await expdb.execute( + text( + """ + SELECT `value` + FROM `task_inputs` + WHERE `task_id` = :task_id + AND `input` = 'evaluation_measures' + """, + ), + parameters={"task_id": task_id}, + ) + result = row.one_or_none() + return result.value if result else None + + async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: """Get trace rows for a run from the trace table.""" rows = await expdb.execute( diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 37a7cecf..68ee0541 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,14 +1,29 @@ """Endpoints for run-related data.""" -from typing import Annotated +import asyncio +from typing import TYPE_CHECKING, Annotated, cast from fastapi import APIRouter, Depends + +if TYPE_CHECKING: + from sqlalchemy import Row from sqlalchemy.ext.asyncio import AsyncConnection +import config +import database.flows import database.runs +import database.setups from core.errors import RunNotFoundError, RunTraceNotFoundError -from routers.dependencies import expdb_connection -from schemas.runs import RunTrace, TraceIteration +from routers.dependencies import expdb_connection, userdb_connection +from schemas.runs import ( + EvaluationScore, + InputDataset, + OutputFile, + ParameterSetting, + Run, + RunTrace, + TraceIteration, +) router = APIRouter(prefix="/run", tags=["run"]) @@ -42,3 +57,122 @@ async def get_run_trace( for row in trace_rows ], ) + + +@router.post( + path="/{run_id}", + description="Provided for convenience, same as `GET` endpoint.", + response_model_exclude_none=True, +) +@router.get("/{run_id}", response_model_exclude_none=True) +async def get_run( + run_id: int, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], + userdb: Annotated[AsyncConnection, Depends(userdb_connection)], +) -> Run: + """Get full metadata for a run by ID. + + No authentication or visibility check is performed — all runs are + publicly accessible. + """ + # Core run record — all other data depends on uploader, setup, and task_id. + run = await database.runs.get(run_id, expdb) + if run is None: + msg = f"Run {run_id} not found." + # Reuse RunNotFoundError and pass code=236 at the call site for + # backward compat with the PHP GET /run/{id} error code + raise RunNotFoundError(msg, code=236) + + # Evaluation engine IDs come from config.toml [run] so they can be + # extended when a new evaluation engine is deployed, without code changes. + engine_ids: list[int] = config.load_run_configuration().get("evaluation_engine_ids", [1]) + + # Fetch all independent data concurrently. + ( + uploader_name, + tags, + input_data_rows, + output_file_rows, + evaluation_rows, + task_type, + task_evaluation_measure, + setup, + parameter_rows, + ) = cast( + "tuple[str | None, list[str], list[Row], list[Row], list[Row], str | None, str" + "| None, Row | None, list[Row]]", + await asyncio.gather( + database.runs.get_uploader_name(run.uploader, userdb), + database.runs.get_tags(run_id, expdb), + database.runs.get_input_data(run_id, expdb), + database.runs.get_output_files(run_id, expdb), + database.runs.get_evaluations(run_id, expdb, evaluation_engine_ids=engine_ids), + database.runs.get_task_type(run.task_id, expdb), + database.runs.get_task_evaluation_measure(run.task_id, expdb), + database.setups.get(run.setup, expdb), + database.setups.get_parameters(run.setup, expdb), + ), + ) + + # Flow is fetched after the gather because it requires setup.implementation_id. + # flows.get() selects fullName AS full_name for reliable case-insensitive access. + flow = await database.flows.get(setup.implementation_id, expdb) if setup else None + + # Build parameter_setting list from the denormalised parameter rows + # returned by database.setups.get_parameters (which already JOINs input + implementation). + parameter_settings = [ + ParameterSetting( + name=p["name"], + value=p["value"], + component=p["flow_id"], # implementation_id of the sub-flow owning this param + ) + for p in parameter_rows + ] + + input_datasets = [ + InputDataset(did=row.did, name=row.name, url=row.url) for row in input_data_rows + ] + + # runfile.field is the file label (e.g. "description", "predictions") + output_files = [OutputFile(file_id=row.file_id, name=row.field) for row in output_file_rows] + + evaluations = [ + EvaluationScore( + name=row.name, + # Whole-number floats (e.g. counts) are converted to int to match PHP's + # integer representation. e.g. 253.0 → 253, 0.0 → 0. + value=int(row.value) + if isinstance(row.value, float) and row.value.is_integer() + else row.value, + array_data=row.array_data, + ) + for row in evaluation_rows + ] + + # Normalise task_evaluation_measure: empty string → None so the field is + # excluded entirely by response_model_exclude_none=True (matches PHP behaviour + # of returning "" but we opt to omit rather than return an empty string). + normalised_measure = task_evaluation_measure or None + + # error_message is NULL in the DB when the run has no error. + # The PHP response returns an empty array [] in that case. + error_messages = [run.error_message] if run.error_message else [] + + return Run( + run_id=run_id, + uploader=run.uploader, + uploader_name=uploader_name, + task_id=run.task_id, + task_type=task_type, + task_evaluation_measure=normalised_measure, + flow_id=setup.implementation_id if setup else 0, + flow_name=flow.full_name if flow else None, + setup_id=run.setup, + setup_string=setup.setup_string if setup else None, + parameter_setting=parameter_settings, + error_message=error_messages, + tag=tags, + # Preserve PHP envelope structure for backward compat + input_data={"dataset": input_datasets}, + output_data={"file": output_files, "evaluation": evaluations}, + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py index 857f4921..991d0059 100644 --- a/src/schemas/runs.py +++ b/src/schemas/runs.py @@ -1,6 +1,6 @@ """Pydantic schemas for run-related endpoints.""" -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field class TraceIteration(BaseModel): @@ -19,3 +19,94 @@ class RunTrace(BaseModel): run_id: int trace: list[TraceIteration] + + +class ParameterSetting(BaseModel): + """A single hyperparameter value used in a run's setup. + + `component` is the `implementation_id` of the flow that defines this + parameter — useful when a setup spans multiple sub-flows (components). + `value` is None when the parameter was not explicitly set (uses default). + """ + + name: str + value: str | None + component: int # = input.implementation_id (flow_id of the owning component) + + +class InputDataset(BaseModel): + """A dataset used as input for a run. + + Sourced from `input_data` JOIN `dataset`. `name` and `url` are fetched + from the `dataset` table and match the values PHP returns. + """ + + did: int + name: str + url: str # ARFF download URL stored in dataset.url + + +class OutputFile(BaseModel): + """An output file produced by or attached to a run. + + Sourced from the `runfile` table. `name` is the file label + (e.g. "description", "predictions"). + + Note: the legacy PHP response included a `did` field hardcoded to "-1" + for every entry here. It originates from a deprecated idea that run outputs + could create new datasets. It is intentionally omitted in this implementation. + """ + + file_id: int + name: str # label as stored in runfile.field, e.g. "description", "predictions" + + +class EvaluationScore(BaseModel): + """An evaluation metric score for a run. + + Sourced from a JOIN of `evaluation` and `math_function`. + `array_data` holds per-fold/per-class breakdowns when available; + `value` holds the aggregate scalar. + + Note: the `evaluation` table does NOT contain `repeat` or `fold` columns. + Only aggregate metrics are available. (Confirmed against the PHP query + in issue #37 which also only selects name, value, array_data.) + """ + + name: str + value: float | int | None # whole numbers returned as int to match PHP + array_data: str | None + + +class Run(BaseModel): + """Full metadata response for a single OpenML run. + + Notes: + - `error_message` is serialized as "error". + - `error_message` is [] (empty list) when the DB column is NULL. + - `task_evaluation_measure` is omitted when null or empty. + + """ + + model_config = ConfigDict(populate_by_name=True) + + run_id: int + uploader: int # user ID of the uploader + uploader_name: str | None # fetched from the separate openml DB (users table) + task_id: int + task_type: str | None # e.g. "Supervised Classification" + task_evaluation_measure: str | None # omitted when null/empty (not returned) + flow_id: int # = algorithm_setup.implementation_id + flow_name: str | None # = implementation.fullName + setup_id: int # = algorithm_setup.sid + setup_string: str | None # human-readable description of the setup + parameter_setting: list[ParameterSetting] + # Serialized as "error" in JSON to match the PHP response key. + # At the Python level we keep the name error_message for clarity. + error_message: list[str] = Field(serialization_alias="error") # [] when NULL in DB + tag: list[str] + input_data: dict[str, list[InputDataset]] # {"dataset": [...]} + output_data: dict[ + str, + list[OutputFile | EvaluationScore], + ] # {"file": [...], "evaluation": [...]} diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py new file mode 100644 index 00000000..0a400fef --- /dev/null +++ b/tests/routers/openml/runs_get_test.py @@ -0,0 +1,383 @@ +"""Tests for GET /run/{id} and POST /run/{id} endpoints.""" + +import asyncio +from http import HTTPStatus +from typing import Any + +import deepdiff +import httpx +import pytest +from sqlalchemy.ext.asyncio import AsyncConnection + +import database.runs +from core.conversions import nested_num_to_str, nested_remove_single_element_list + +# ── Fixtures assume run 24 exists in the test DB (confirmed in research) ── +_RUN_ID = 24 +_MISSING_RUN_ID = 999_999_999 +_RUN_NOT_FOUND_CODE = "236" # PHP compat error code (not 220) + +_RUN_UPLOADER_ID = 1159 +_RUN_TASK_ID = 115 +_RUN_FLOW_ID = 19 +_RUN_SETUP_ID = 2 +_RUN_DATASET_ID = 20 +_DESCRIPTION_FILE_ID = 182 +_PREDICTIONS_FILE_ID = 183 + + +# ════════════════════════════════════════════════════════════════════ +# Happy-path API tests (use py_api httpx client) +# ════════════════════════════════════════════════════════════════════ + + +async def test_get_run_status_ok(py_api: httpx.AsyncClient) -> None: + """GET /run/{id} returns 200 for a known run.""" + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + +async def test_post_run_status_ok(py_api: httpx.AsyncClient) -> None: + """POST /run/{id} returns 200 — convenience alias parity.""" + response = await py_api.post(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + +async def test_get_and_post_run_identical(py_api: httpx.AsyncClient) -> None: + """GET and POST /run/{id} return identical JSON bodies.""" + get_resp, post_resp = await asyncio.gather( + py_api.get(f"/run/{_RUN_ID}"), + py_api.post(f"/run/{_RUN_ID}"), + ) + assert get_resp.status_code == HTTPStatus.OK + assert post_resp.status_code == HTTPStatus.OK + assert get_resp.json() == post_resp.json() + + +async def test_get_run_top_level_shape(py_api: httpx.AsyncClient) -> None: + """Response contains all expected top-level keys.""" + response = await py_api.get(f"/run/{_RUN_ID}") + run = response.json() + expected_keys = { + "run_id", + "uploader", + "uploader_name", + "task_id", + "task_type", + "flow_id", + "flow_name", + "setup_id", + "setup_string", + "parameter_setting", + "error", + "tag", + "input_data", + "output_data", + } + assert expected_keys <= run.keys(), f"Missing keys: {expected_keys - run.keys()}" + + +async def test_get_run_known_values(py_api: httpx.AsyncClient) -> None: + """Run 24 returns the exact values confirmed against the DB.""" + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + run = response.json() + + # Core identifiers + assert run["run_id"] == _RUN_ID + assert run["uploader"] == _RUN_UPLOADER_ID + assert run["uploader_name"] == "Cynthia Glover" + assert run["task_id"] == _RUN_TASK_ID + assert run["task_type"] == "Supervised Classification" + assert run["flow_id"] == _RUN_FLOW_ID + assert run["setup_id"] == _RUN_SETUP_ID + assert "Python_3.10.5" in run["setup_string"] + + # Tags + assert "openml-python" in run["tag"] + + # Error — NULL in DB → empty list + assert run["error"] == [] + + +async def test_get_run_input_data_shape(py_api: httpx.AsyncClient) -> None: + """input_data has the PHP envelope structure {"dataset": [...]}.""" + response = await py_api.get(f"/run/{_RUN_ID}") + run = response.json() + assert "dataset" in run["input_data"] + datasets = run["input_data"]["dataset"] + assert isinstance(datasets, list) + assert len(datasets) > 0 + dataset = datasets[0] + assert "did" in dataset + assert "name" in dataset + assert "url" in dataset + # Run 24 uses diabetes dataset (did=20), confirmed in DB + assert dataset["did"] == _RUN_DATASET_ID + assert dataset["name"] == "diabetes" + + +async def test_get_run_output_data_shape(py_api: httpx.AsyncClient) -> None: + """output_data has {"file": [...], "evaluation": [...]} structure.""" + response = await py_api.get(f"/run/{_RUN_ID}") + run = response.json() + assert "file" in run["output_data"] + assert "evaluation" in run["output_data"] + + files = run["output_data"]["file"] + assert isinstance(files, list) + assert len(files) > 0 + file_ = files[0] + assert "file_id" in file_ + assert "name" in file_ + # Deprecated `did: "-1"` must NOT be present (intentionally omitted) + assert "did" not in file_ + + evaluations = run["output_data"]["evaluation"] + assert isinstance(evaluations, list) + assert len(evaluations) > 0 + eval_ = evaluations[0] + assert "name" in eval_ + assert "value" in eval_ + + +async def test_get_run_output_files_known(py_api: httpx.AsyncClient) -> None: + """Run 24 output files are description (182) and predictions (183).""" + response = await py_api.get(f"/run/{_RUN_ID}") + files = response.json()["output_data"]["file"] + file_map = {f["name"]: f["file_id"] for f in files} + assert file_map.get("description") == _DESCRIPTION_FILE_ID + assert file_map.get("predictions") == _PREDICTIONS_FILE_ID + + +async def test_get_run_evaluation_known(py_api: httpx.AsyncClient) -> None: + """Run 24 evaluations include area_under_roc_curve.""" + response = await py_api.get(f"/run/{_RUN_ID}") + evals = response.json()["output_data"]["evaluation"] + eval_names = {e["name"] for e in evals} + assert "area_under_roc_curve" in eval_names + + +async def test_get_run_integer_evaluation_values(py_api: httpx.AsyncClient) -> None: + """Whole-number floats in evaluations are returned as int (PHP compat).""" + response = await py_api.get(f"/run/{_RUN_ID}") + evals = response.json()["output_data"]["evaluation"] + for ev in evals: + if ev["value"] is not None and isinstance(ev["value"], float): + # If it's a float in the JSON, it must NOT be a whole number + # (whole-number floats should have been cast to int already) + assert ev["value"] != int(ev["value"]), ( + f"Expected {ev['name']} value {ev['value']} to be int, not float" + ) + + +async def test_get_run_parameter_setting_shape(py_api: httpx.AsyncClient) -> None: + """parameter_setting entries have name, value, component keys.""" + response = await py_api.get(f"/run/{_RUN_ID}") + params = response.json()["parameter_setting"] + assert isinstance(params, list) + for p in params: + assert "name" in p + assert "value" in p + assert "component" in p + assert isinstance(p["component"], int) + + +async def test_get_run_not_found(py_api: httpx.AsyncClient) -> None: + """Non-existent run returns 404 with error code 236 (PHP compat).""" + response = await py_api.get(f"/run/{_MISSING_RUN_ID}") + assert response.status_code == HTTPStatus.NOT_FOUND + error = response.json() + # Verify PHP-compat error code + assert str(error.get("code")) == _RUN_NOT_FOUND_CODE + + +async def test_get_run_invalid_id_type(py_api: httpx.AsyncClient) -> None: + """Non-integer run ID returns 422 Unprocessable Entity.""" + response = await py_api.get("/run/not-a-number") + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +# ════════════════════════════════════════════════════════════════════ +# Functional / unit-level tests (call database functions directly) +# ════════════════════════════════════════════════════════════════════ + + +async def test_db_get_run_exists(expdb_test: AsyncConnection) -> None: + """database.runs.get returns a row for run 24.""" + row = await database.runs.get(_RUN_ID, expdb_test) + assert row is not None + assert row.rid == _RUN_ID + assert row.uploader == _RUN_UPLOADER_ID + assert row.task_id == _RUN_TASK_ID + assert row.setup == _RUN_SETUP_ID + assert row.error_message is None # no error for this run + + +async def test_db_get_run_missing(expdb_test: AsyncConnection) -> None: + """database.runs.get returns None for a non-existent run.""" + row = await database.runs.get(_MISSING_RUN_ID, expdb_test) + assert row is None + + +async def test_db_exist_true(expdb_test: AsyncConnection) -> None: + """database.runs.exist returns True for run 24.""" + assert await database.runs.exist(_RUN_ID, expdb_test) is True + + +async def test_db_exist_false(expdb_test: AsyncConnection) -> None: + """database.runs.exist returns False for a missing run.""" + assert await database.runs.exist(_MISSING_RUN_ID, expdb_test) is False + + +async def test_db_get_tags(expdb_test: AsyncConnection) -> None: + """database.runs.get_tags returns expected tags for run 24.""" + tags = await database.runs.get_tags(_RUN_ID, expdb_test) + assert isinstance(tags, list) + assert "openml-python" in tags + + +async def test_db_get_input_data(expdb_test: AsyncConnection) -> None: + """database.runs.get_input_data returns did=20 (diabetes) for run 24.""" + rows = await database.runs.get_input_data(_RUN_ID, expdb_test) + assert len(rows) >= 1 + dids = [r.did for r in rows] + assert _RUN_DATASET_ID in dids + + +async def test_db_get_output_files(expdb_test: AsyncConnection) -> None: + """database.runs.get_output_files returns description and predictions files.""" + rows = await database.runs.get_output_files(_RUN_ID, expdb_test) + file_map = {r.field: r.file_id for r in rows} + assert file_map.get("description") == _DESCRIPTION_FILE_ID + assert file_map.get("predictions") == _PREDICTIONS_FILE_ID + + +async def test_db_get_evaluations(expdb_test: AsyncConnection) -> None: + """database.runs.get_evaluations returns metrics including area_under_roc_curve.""" + rows = await database.runs.get_evaluations(_RUN_ID, expdb_test, evaluation_engine_ids=[1]) + assert len(rows) > 0 + names = {r.name for r in rows} + assert "area_under_roc_curve" in names + + +async def test_db_get_evaluations_empty_engine_list(expdb_test: AsyncConnection) -> None: + """get_evaluations with no engine IDs returns an empty list (not an error).""" + rows = await database.runs.get_evaluations(_RUN_ID, expdb_test, evaluation_engine_ids=[]) + assert isinstance(rows, list) + + +async def test_db_get_task_type(expdb_test: AsyncConnection) -> None: + """database.runs.get_task_type returns 'Supervised Classification' for task 115.""" + task_type = await database.runs.get_task_type(115, expdb_test) + assert task_type == "Supervised Classification" + + +async def test_db_get_task_evaluation_measure_missing(expdb_test: AsyncConnection) -> None: + """get_task_evaluation_measure returns None (not '') when absent.""" + measure = await database.runs.get_task_evaluation_measure(115, expdb_test) + assert measure is None + + +async def test_db_get_uploader_name(user_test: AsyncConnection) -> None: + """database.runs.get_uploader_name returns 'Cynthia Glover' for user 1159.""" + name = await database.runs.get_uploader_name(1159, user_test) + assert name == "Cynthia Glover" + + +async def test_db_get_uploader_name_missing(user_test: AsyncConnection) -> None: + """get_uploader_name returns None for a non-existent user.""" + name = await database.runs.get_uploader_name(_MISSING_RUN_ID, user_test) + assert name is None + + +# ════════════════════════════════════════════════════════════════════ +# Migration tests (Python API vs PHP API parity) +# ════════════════════════════════════════════════════════════════════ + +# Regex paths excluded from DeepDiff — only genuinely untestable fields. +_EXCLUDE_PATHS = [ + # [1] PHP hardcodes did="-1" in output_data.file; Python omits it (deprecated). + r"root\['run'\]\['output_data'\]\['file'\]\[\d+\]\['did'\]", + # [2] PHP generates output file URLs from its own server_url config. + # Python does not yet have a file download endpoint, so URLs differ by design. + r"root\['run'\]\['output_data'\]\['file'\]\[\d+\]\['url'\]", +] + + +def _normalize_py_run(py_run: dict[str, Any]) -> dict[str, Any]: + """Normalize a Python run response to match the PHP response format.""" + run = py_run.copy() + + # Collapse single-element lists to match PHP XML-to-JSON behaviour. + run = nested_remove_single_element_list(run) + + # PHP returns all numbers as strings — convert to match. + run = nested_num_to_str(run) + + # PHP wraps the response envelope. + return {"run": run} + + +# Run IDs to test, including a non-existent one to verify error parity. +_RUN_IDS = [*range(24, 35), 999_999_999] + + +@pytest.mark.parametrize("run_id", _RUN_IDS) +async def test_get_run_equal( + run_id: int, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + """Python and PHP run responses are equivalent after normalization.""" + py_response, php_response = await asyncio.gather( + py_api.get(f"/run/{run_id}"), + php_api.get(f"/run/{run_id}"), + ) + + # Error case: run does not exist. + # PHP returns 412 PRECONDITION_FAILED; Python returns 404 NOT_FOUND. + if php_response.status_code != HTTPStatus.OK: + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.status_code == HTTPStatus.NOT_FOUND + php_code = php_response.json()["error"]["code"] + py_code = py_response.json()["code"] + assert py_code == php_code + return + + assert py_response.status_code == HTTPStatus.OK + + py_normalized = _normalize_py_run(py_response.json()) + php_json = php_response.json() + + # PHP duplicates evaluation entries natively for each fold, and also provides + # an aggregate with `repeat="0"` and `fold="0"`. The Python API correctly provides + # only the aggregate row (and array_data string). + # To match without complex deepdiff matchers, simply verify the base aggregate entries. + if ( + "run" in php_json + and "output_data" in php_json["run"] + and "evaluation" in php_json["run"]["output_data"] + ): + php_evals = php_json["run"]["output_data"]["evaluation"] + if isinstance(php_evals, list): + php_json["run"]["output_data"]["evaluation"] = [ + ev for ev in php_evals if "repeat" not in ev and "fold" not in ev + ] + elif isinstance(php_evals, dict) and ("repeat" in php_evals or "fold" in php_evals): + # nested_remove_single_element_list removes lists if there's only 1 element, but PHP + # original JSON might have had only 1 base evaluation if no others existed. + # But PHP returns a list anyway if duplicates exist. If they don't, it's a dict. + php_json["run"]["output_data"]["evaluation"] = [] + + # PHP sometimes includes empty `error` property instead of an empty list when no error occurred + # DeepDiff takes care of it automatically because we didn't see error diffs. + + differences = deepdiff.diff.DeepDiff( + py_normalized, + php_json, + ignore_order=True, + ignore_numeric_type_changes=True, + exclude_regex_paths=_EXCLUDE_PATHS, + ) + assert not differences, f"Differences for run {run_id}: {differences}" From 4ca818c8ff456204fa0ab77ece55919140ac9daf Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Wed, 22 Apr 2026 11:20:09 +0530 Subject: [PATCH 2/4] address bot feedback, add test_evaluation test --- src/database/runs.py | 18 ++- src/routers/openml/runs.py | 160 ++++++++++++++++---------- src/schemas/runs.py | 24 ++-- tests/routers/openml/runs_get_test.py | 41 ++++++- 4 files changed, 161 insertions(+), 82 deletions(-) diff --git a/src/database/runs.py b/src/database/runs.py index 5f7f2afb..788a605a 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Row, text +from sqlalchemy import Row, bindparam, text from sqlalchemy.ext.asyncio import AsyncConnection @@ -50,7 +50,7 @@ async def get_uploader_name(uploader_id: int, userdb: AsyncConnection) -> str | row = await userdb.execute( text( """ - SELECT CONCAT(`first_name`, ' ', `last_name`) AS `name` + SELECT CONCAT_WS(' ', `first_name`, `last_name`) AS `name` FROM `users` WHERE `id` = :uploader_id """, @@ -138,22 +138,18 @@ async def get_evaluations( if not evaluation_engine_ids: return [] - # Build :eid_0, :eid_1, ... placeholders — one per engine ID. - eid_params = {f"eid_{i}": eid for i, eid in enumerate(evaluation_engine_ids)} - placeholders = ", ".join(f":eid_{i}" for i in range(len(evaluation_engine_ids))) - query = text( - f""" + """ SELECT `m`.`name`, `e`.`value`, `e`.`array_data` FROM `evaluation` `e` JOIN `math_function` `m` ON `e`.`function_id` = `m`.`id` WHERE `e`.`source` = :run_id - AND `e`.`evaluation_engine_id` IN ({placeholders}) - """, # noqa: S608 # placeholders are trusted integer params, not user input - ) + AND `e`.`evaluation_engine_id` IN :engine_ids + """, + ).bindparams(bindparam("engine_ids", expanding=True)) rows = await expdb.execute( query, - parameters={"run_id": run_id, **eid_params}, + parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids}, ) return cast("list[Row]", rows.all()) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 68ee0541..113c5817 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,6 +1,7 @@ """Endpoints for run-related data.""" import asyncio +from dataclasses import dataclass from typing import TYPE_CHECKING, Annotated, cast from fastapi import APIRouter, Depends @@ -17,7 +18,9 @@ from routers.dependencies import expdb_connection, userdb_connection from schemas.runs import ( EvaluationScore, + InputData, InputDataset, + OutputData, OutputFile, ParameterSetting, Run, @@ -59,35 +62,28 @@ async def get_run_trace( ) -@router.post( - path="/{run_id}", - description="Provided for convenience, same as `GET` endpoint.", - response_model_exclude_none=True, -) -@router.get("/{run_id}", response_model_exclude_none=True) -async def get_run( - run_id: int, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)], - userdb: Annotated[AsyncConnection, Depends(userdb_connection)], -) -> Run: - """Get full metadata for a run by ID. +@dataclass +class RunContext: + """Helper context to store concurrently fetched run dependencies.""" - No authentication or visibility check is performed — all runs are - publicly accessible. - """ - # Core run record — all other data depends on uploader, setup, and task_id. - run = await database.runs.get(run_id, expdb) - if run is None: - msg = f"Run {run_id} not found." - # Reuse RunNotFoundError and pass code=236 at the call site for - # backward compat with the PHP GET /run/{id} error code - raise RunNotFoundError(msg, code=236) + uploader_name: str | None + tags: list[str] + input_data_rows: list["Row"] + output_file_rows: list["Row"] + evaluation_rows: list["Row"] + task_type: str | None + task_evaluation_measure: str | None + setup: "Row | None" + parameter_rows: list["Row"] - # Evaluation engine IDs come from config.toml [run] so they can be - # extended when a new evaluation engine is deployed, without code changes. - engine_ids: list[int] = config.load_run_configuration().get("evaluation_engine_ids", [1]) - # Fetch all independent data concurrently. +async def _load_run_context( + run: "Row", + run_id: int, + expdb: AsyncConnection, + userdb: AsyncConnection, + engine_ids: list[int], +) -> RunContext: ( uploader_name, tags, @@ -99,8 +95,8 @@ async def get_run( setup, parameter_rows, ) = cast( - "tuple[str | None, list[str], list[Row], list[Row], list[Row], str | None, str" - "| None, Row | None, list[Row]]", + "tuple[str | None, list[str], list[Row], list[Row], list[Row], str | None, str |" + "None, Row | None, list[Row]]", await asyncio.gather( database.runs.get_uploader_name(run.uploader, userdb), database.runs.get_tags(run_id, expdb), @@ -113,66 +109,108 @@ async def get_run( database.setups.get_parameters(run.setup, expdb), ), ) + return RunContext( + uploader_name=uploader_name, + tags=tags, + input_data_rows=input_data_rows, + output_file_rows=output_file_rows, + evaluation_rows=evaluation_rows, + task_type=task_type, + task_evaluation_measure=task_evaluation_measure, + setup=setup, + parameter_rows=parameter_rows, + ) - # Flow is fetched after the gather because it requires setup.implementation_id. - # flows.get() selects fullName AS full_name for reliable case-insensitive access. - flow = await database.flows.get(setup.implementation_id, expdb) if setup else None - # Build parameter_setting list from the denormalised parameter rows - # returned by database.setups.get_parameters (which already JOINs input + implementation). - parameter_settings = [ +def _build_parameter_settings(parameter_rows: list["Row"]) -> list[ParameterSetting]: + return [ ParameterSetting( name=p["name"], value=p["value"], - component=p["flow_id"], # implementation_id of the sub-flow owning this param + component=p["flow_id"], ) for p in parameter_rows ] - input_datasets = [ - InputDataset(did=row.did, name=row.name, url=row.url) for row in input_data_rows - ] - # runfile.field is the file label (e.g. "description", "predictions") - output_files = [OutputFile(file_id=row.file_id, name=row.field) for row in output_file_rows] +def _build_input_datasets(rows: list["Row"]) -> list[InputDataset]: + return [InputDataset(did=row.did, name=row.name, url=row.url) for row in rows] + + +def _build_output_files(rows: list["Row"]) -> list[OutputFile]: + return [OutputFile(file_id=row.file_id, name=row.field) for row in rows] + - evaluations = [ +def _build_evaluations(rows: list["Row"]) -> list[EvaluationScore]: + def _normalise_value(v: object) -> object: + if isinstance(v, (int, float)): + return int(v) if float(v).is_integer() else v + if isinstance(v, str): + try: + f = float(v) + return int(f) if f.is_integer() else f + except ValueError: + return None + return None + + return [ EvaluationScore( name=row.name, - # Whole-number floats (e.g. counts) are converted to int to match PHP's - # integer representation. e.g. 253.0 → 253, 0.0 → 0. - value=int(row.value) - if isinstance(row.value, float) and row.value.is_integer() - else row.value, + value=_normalise_value(row.value), array_data=row.array_data, ) - for row in evaluation_rows + for row in rows ] - # Normalise task_evaluation_measure: empty string → None so the field is - # excluded entirely by response_model_exclude_none=True (matches PHP behaviour - # of returning "" but we opt to omit rather than return an empty string). - normalised_measure = task_evaluation_measure or None - # error_message is NULL in the DB when the run has no error. - # The PHP response returns an empty array [] in that case. +@router.post( + path="/{run_id}", + description="Provided for convenience, same as `GET` endpoint.", + response_model_exclude_none=True, +) +@router.get("/{run_id}", response_model_exclude_none=True) +async def get_run( + run_id: int, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], + userdb: Annotated[AsyncConnection, Depends(userdb_connection)], +) -> Run: + """Get full metadata for a run by ID. + + No authentication or visibility check is performed — all runs are + publicly accessible. + """ + run = await database.runs.get(run_id, expdb) + if run is None: + msg = f"Run {run_id} not found." + raise RunNotFoundError(msg, code=236) + + engine_ids: list[int] = config.load_run_configuration().get("evaluation_engine_ids", [1]) + ctx = await _load_run_context(run, run_id, expdb, userdb, engine_ids) + + flow = await database.flows.get(ctx.setup.implementation_id, expdb) if ctx.setup else None + + parameter_settings = _build_parameter_settings(ctx.parameter_rows) + input_datasets = _build_input_datasets(ctx.input_data_rows) + output_files = _build_output_files(ctx.output_file_rows) + evaluations = _build_evaluations(ctx.evaluation_rows) + + normalised_measure = ctx.task_evaluation_measure or None error_messages = [run.error_message] if run.error_message else [] return Run( run_id=run_id, uploader=run.uploader, - uploader_name=uploader_name, + uploader_name=ctx.uploader_name, task_id=run.task_id, - task_type=task_type, + task_type=ctx.task_type, task_evaluation_measure=normalised_measure, - flow_id=setup.implementation_id if setup else 0, + flow_id=ctx.setup.implementation_id if ctx.setup else None, flow_name=flow.full_name if flow else None, setup_id=run.setup, - setup_string=setup.setup_string if setup else None, + setup_string=ctx.setup.setup_string if ctx.setup else None, parameter_setting=parameter_settings, error_message=error_messages, - tag=tags, - # Preserve PHP envelope structure for backward compat - input_data={"dataset": input_datasets}, - output_data={"file": output_files, "evaluation": evaluations}, + tag=ctx.tags, + input_data=InputData(dataset=input_datasets), + output_data=OutputData(file=output_files, evaluation=evaluations), ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py index 991d0059..3ab16110 100644 --- a/src/schemas/runs.py +++ b/src/schemas/runs.py @@ -78,6 +78,19 @@ class EvaluationScore(BaseModel): array_data: str | None +class InputData(BaseModel): + """Wrapper for input datasets configuration.""" + + dataset: list[InputDataset] + + +class OutputData(BaseModel): + """Wrapper for output files and evaluations.""" + + file: list[OutputFile] + evaluation: list[EvaluationScore] + + class Run(BaseModel): """Full metadata response for a single OpenML run. @@ -96,17 +109,14 @@ class Run(BaseModel): task_id: int task_type: str | None # e.g. "Supervised Classification" task_evaluation_measure: str | None # omitted when null/empty (not returned) - flow_id: int # = algorithm_setup.implementation_id + flow_id: int | None = None # = algorithm_setup.implementation_id; None when no setup flow_name: str | None # = implementation.fullName - setup_id: int # = algorithm_setup.sid + setup_id: int | None = None # = algorithm_setup.sid; None when run has no setup setup_string: str | None # human-readable description of the setup parameter_setting: list[ParameterSetting] # Serialized as "error" in JSON to match the PHP response key. # At the Python level we keep the name error_message for clarity. error_message: list[str] = Field(serialization_alias="error") # [] when NULL in DB tag: list[str] - input_data: dict[str, list[InputDataset]] # {"dataset": [...]} - output_data: dict[ - str, - list[OutputFile | EvaluationScore], - ] # {"file": [...], "evaluation": [...]} + input_data: InputData + output_data: OutputData diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py index 0a400fef..40965123 100644 --- a/tests/routers/openml/runs_get_test.py +++ b/tests/routers/openml/runs_get_test.py @@ -11,11 +11,12 @@ import database.runs from core.conversions import nested_num_to_str, nested_remove_single_element_list +from routers.openml.runs import _build_evaluations # ── Fixtures assume run 24 exists in the test DB (confirmed in research) ── _RUN_ID = 24 _MISSING_RUN_ID = 999_999_999 -_RUN_NOT_FOUND_CODE = "236" # PHP compat error code (not 220) +_RUN_NOT_FOUND_CODE = "236" _RUN_UPLOADER_ID = 1159 _RUN_TASK_ID = 115 @@ -96,7 +97,7 @@ async def test_get_run_known_values(py_api: httpx.AsyncClient) -> None: # Tags assert "openml-python" in run["tag"] - # Error — NULL in DB → empty list + # Error — NULL in DB -> empty list assert run["error"] == [] @@ -192,6 +193,14 @@ async def test_get_run_not_found(py_api: httpx.AsyncClient) -> None: assert str(error.get("code")) == _RUN_NOT_FOUND_CODE +async def test_task_evaluation_measure_omitted_when_null(py_api: httpx.AsyncClient) -> None: + """task_evaluation_measure is not present in JSON when no measure is configured.""" + # Run 24 is known to not have a task evaluation measure (verified in db test) + response = await py_api.get(f"/run/{_RUN_ID}") + run = response.json() + assert "task_evaluation_measure" not in run + + async def test_get_run_invalid_id_type(py_api: httpx.AsyncClient) -> None: """Non-integer run ID returns 422 Unprocessable Entity.""" response = await py_api.get("/run/not-a-number") @@ -340,7 +349,7 @@ async def test_get_run_equal( if php_response.status_code != HTTPStatus.OK: assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED assert py_response.status_code == HTTPStatus.NOT_FOUND - php_code = php_response.json()["error"]["code"] + php_code = php_response.json().get("error", {}).get("code") py_code = py_response.json()["code"] assert py_code == php_code return @@ -381,3 +390,29 @@ async def test_get_run_equal( exclude_regex_paths=_EXCLUDE_PATHS, ) assert not differences, f"Differences for run {run_id}: {differences}" + + +def test_build_evaluations_coverage() -> None: + """Ensure _build_evaluations string-normalization branches are covered.""" + + class MockRow: + def __init__(self, name: str, value: object, array_data: str | None = None) -> None: + self.name = name + self.value = value + self.array_data = array_data + + rows = [ + MockRow("float_val", 1.0), + MockRow("str_float", "2.0"), + MockRow("str_text", "not_a_number"), + MockRow("unhandled_type", ["list"]), + ] + evals = _build_evaluations(rows) + + values = {e.name: e.value for e in evals} + expected_one = 1 + expected_two = 2 + assert values["float_val"] == expected_one + assert values["str_float"] == expected_two + assert values["str_text"] is None + assert values["unhandled_type"] is None From 4461c6363a210a60809451ed0e7112261326a95a Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Wed, 22 Apr 2026 15:45:00 +0530 Subject: [PATCH 3/4] add task_evaluation_measure test, address bot review --- src/config.py | 1 + tests/routers/openml/runs_get_test.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/config.py b/src/config.py index 011a5624..04dd0fc9 100644 --- a/src/config.py +++ b/src/config.py @@ -54,6 +54,7 @@ def load_routing_configuration(file: Path = _config_file) -> TomlTable: return typing.cast("TomlTable", _load_configuration(file)["routing"]) +@functools.cache def load_run_configuration(file: Path = _config_file) -> TomlTable: return typing.cast("TomlTable", _load_configuration(file).get("run", {})) diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py index 40965123..9602006c 100644 --- a/tests/routers/openml/runs_get_test.py +++ b/tests/routers/openml/runs_get_test.py @@ -3,6 +3,7 @@ import asyncio from http import HTTPStatus from typing import Any +from unittest.mock import AsyncMock, patch import deepdiff import httpx @@ -201,6 +202,23 @@ async def test_task_evaluation_measure_omitted_when_null(py_api: httpx.AsyncClie assert "task_evaluation_measure" not in run +async def test_task_evaluation_measure_present_when_configured( + py_api: httpx.AsyncClient, +) -> None: + """task_evaluation_measure is present and matches DB when a measure is configured.""" + # Since the test database does not have a run with an evaluation measure, we mock the DB fetch + with patch( + "routers.openml.runs.database.runs.get_task_evaluation_measure", new_callable=AsyncMock + ) as mock_get_measure: + mock_get_measure.return_value = "predictive_accuracy" + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + run = response.json() + assert "task_evaluation_measure" in run + assert run["task_evaluation_measure"] == "predictive_accuracy" + + async def test_get_run_invalid_id_type(py_api: httpx.AsyncClient) -> None: """Non-integer run ID returns 422 Unprocessable Entity.""" response = await py_api.get("/run/not-a-number") From 34c0e28a61465632a1011056bfccb104df3073c3 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Thu, 23 Apr 2026 10:47:33 +0530 Subject: [PATCH 4/4] address bot feedback for run denpoint tests, add async mock DB, cleanup --- tests/routers/openml/runs_get_test.py | 62 ++++++++++++++++++++------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py index 9602006c..5db054bf 100644 --- a/tests/routers/openml/runs_get_test.py +++ b/tests/routers/openml/runs_get_test.py @@ -2,7 +2,7 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import Any, NamedTuple from unittest.mock import AsyncMock, patch import deepdiff @@ -17,6 +17,7 @@ # ── Fixtures assume run 24 exists in the test DB (confirmed in research) ── _RUN_ID = 24 _MISSING_RUN_ID = 999_999_999 +_MISSING_USER_ID = 999_999_999 _RUN_NOT_FOUND_CODE = "236" _RUN_UPLOADER_ID = 1159 @@ -102,6 +103,34 @@ async def test_get_run_known_values(py_api: httpx.AsyncClient) -> None: assert run["error"] == [] +async def test_get_run_non_empty_error(py_api: httpx.AsyncClient) -> None: + """A run with a non-null error_message is serialized as a single-item error list.""" + + # Since the test database does not have a run with an error, we mock the DB fetch + class MockRunRow(NamedTuple): + rid: int + uploader: int + setup: int + task_id: int + error_message: str + + mock_row = MockRunRow( + rid=_RUN_ID, + uploader=_RUN_UPLOADER_ID, + setup=_RUN_SETUP_ID, + task_id=_RUN_TASK_ID, + error_message="Some error from the backend", + ) + + with patch("routers.openml.runs.database.runs.get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_row + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + run = response.json() + assert run["error"] == ["Some error from the backend"] + + async def test_get_run_input_data_shape(py_api: httpx.AsyncClient) -> None: """input_data has the PHP envelope structure {"dataset": [...]}.""" response = await py_api.get(f"/run/{_RUN_ID}") @@ -291,30 +320,30 @@ async def test_db_get_evaluations(expdb_test: AsyncConnection) -> None: async def test_db_get_evaluations_empty_engine_list(expdb_test: AsyncConnection) -> None: """get_evaluations with no engine IDs returns an empty list (not an error).""" rows = await database.runs.get_evaluations(_RUN_ID, expdb_test, evaluation_engine_ids=[]) - assert isinstance(rows, list) + assert rows == [] async def test_db_get_task_type(expdb_test: AsyncConnection) -> None: """database.runs.get_task_type returns 'Supervised Classification' for task 115.""" - task_type = await database.runs.get_task_type(115, expdb_test) + task_type = await database.runs.get_task_type(_RUN_TASK_ID, expdb_test) assert task_type == "Supervised Classification" async def test_db_get_task_evaluation_measure_missing(expdb_test: AsyncConnection) -> None: """get_task_evaluation_measure returns None (not '') when absent.""" - measure = await database.runs.get_task_evaluation_measure(115, expdb_test) + measure = await database.runs.get_task_evaluation_measure(_RUN_TASK_ID, expdb_test) assert measure is None async def test_db_get_uploader_name(user_test: AsyncConnection) -> None: """database.runs.get_uploader_name returns 'Cynthia Glover' for user 1159.""" - name = await database.runs.get_uploader_name(1159, user_test) + name = await database.runs.get_uploader_name(_RUN_UPLOADER_ID, user_test) assert name == "Cynthia Glover" async def test_db_get_uploader_name_missing(user_test: AsyncConnection) -> None: """get_uploader_name returns None for a non-existent user.""" - name = await database.runs.get_uploader_name(_MISSING_RUN_ID, user_test) + name = await database.runs.get_uploader_name(_MISSING_USER_ID, user_test) assert name is None @@ -377,10 +406,10 @@ async def test_get_run_equal( py_normalized = _normalize_py_run(py_response.json()) php_json = php_response.json() - # PHP duplicates evaluation entries natively for each fold, and also provides - # an aggregate with `repeat="0"` and `fold="0"`. The Python API correctly provides - # only the aggregate row (and array_data string). - # To match without complex deepdiff matchers, simply verify the base aggregate entries. + # PHP provides evaluation entries natively for each fold (with `repeat` and `fold` keys) + # as well as an aggregate entry (which might or might not have those keys depending on version). + # To match without complex deepdiff matchers, verify base entries without repeat/fold + # and drop the rest. if ( "run" in php_json and "output_data" in php_json["run"] @@ -391,11 +420,11 @@ async def test_get_run_equal( php_json["run"]["output_data"]["evaluation"] = [ ev for ev in php_evals if "repeat" not in ev and "fold" not in ev ] - elif isinstance(php_evals, dict) and ("repeat" in php_evals or "fold" in php_evals): - # nested_remove_single_element_list removes lists if there's only 1 element, but PHP - # original JSON might have had only 1 base evaluation if no others existed. - # But PHP returns a list anyway if duplicates exist. If they don't, it's a dict. - php_json["run"]["output_data"]["evaluation"] = [] + elif isinstance(php_evals, dict): + if "repeat" in php_evals or "fold" in php_evals: + php_json["run"]["output_data"]["evaluation"] = [] + else: + php_json["run"]["output_data"]["evaluation"] = [php_evals] # PHP sometimes includes empty `error` property instead of an empty list when no error occurred # DeepDiff takes care of it automatically because we didn't see error diffs. @@ -422,6 +451,7 @@ def __init__(self, name: str, value: object, array_data: str | None = None) -> N rows = [ MockRow("float_val", 1.0), MockRow("str_float", "2.0"), + MockRow("str_float_fraction", "1.5"), MockRow("str_text", "not_a_number"), MockRow("unhandled_type", ["list"]), ] @@ -430,7 +460,9 @@ def __init__(self, name: str, value: object, array_data: str | None = None) -> N values = {e.name: e.value for e in evals} expected_one = 1 expected_two = 2 + expected_fraction = 1.5 assert values["float_val"] == expected_one assert values["str_float"] == expected_two + assert values["str_float_fraction"] == expected_fraction assert values["str_text"] is None assert values["unhandled_type"] is None