diff --git a/changelog.d/3391.fixed.md b/changelog.d/3391.fixed.md new file mode 100644 index 000000000..16d799cbe --- /dev/null +++ b/changelog.d/3391.fixed.md @@ -0,0 +1,2 @@ +* Add rerun support for legacy report outputs, including reset of linked simulations and economy cache rows. +* Pin `policyengine` below `1.0` so the legacy API continues to use the compatible `.py` interface. diff --git a/changelog.d/fixed/3394.md b/changelog.d/3394.fixed.md similarity index 100% rename from changelog.d/fixed/3394.md rename to changelog.d/3394.fixed.md diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 93256d778..1a38a394a 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -2,12 +2,100 @@ from werkzeug.exceptions import NotFound, BadRequest import json +from policyengine_api.constants import ( + CURRENT_YEAR, + get_economy_impact_cache_version, +) +from policyengine_api.services.reform_impacts_service import ReformImpactsService from policyengine_api.services.report_output_service import ReportOutputService -from policyengine_api.constants import CURRENT_YEAR +from policyengine_api.services.simulation_service import SimulationService from policyengine_api.utils.payload_validators import validate_country report_output_bp = Blueprint("report_output", __name__) report_output_service = ReportOutputService() +simulation_service = SimulationService() +reform_impacts_service = ReformImpactsService() + + +def _get_linked_simulation_or_raise(country_id: str, simulation_id: int) -> dict: + simulation = simulation_service.get_simulation(country_id, simulation_id) + if simulation is None: + raise BadRequest( + f"Report references simulation #{simulation_id}, but it could not be found for country {country_id}." + ) + return simulation + + +def _load_report_and_linked_simulations( + country_id: str, report_id: int +) -> tuple[dict, dict, dict | None]: + report_output = report_output_service.get_stored_report_output(report_id) + if report_output is None or report_output["country_id"] != country_id: + raise NotFound(f"Report #{report_id} not found.") + + simulation_1 = _get_linked_simulation_or_raise( + country_id=country_id, + simulation_id=report_output["simulation_1_id"], + ) + + simulation_2 = None + if report_output["simulation_2_id"] is not None: + simulation_2 = _get_linked_simulation_or_raise( + country_id=country_id, + simulation_id=report_output["simulation_2_id"], + ) + + if ( + simulation_2 is not None + and simulation_1["population_type"] != simulation_2["population_type"] + ): + raise BadRequest( + f"Report #{report_id} links simulations with mismatched population types." + ) + + return report_output, simulation_1, simulation_2 + + +def _reset_linked_simulations(country_id: str, *simulations: dict | None) -> list[int]: + reset_simulation_ids: list[int] = [] + seen_ids: set[int] = set() + + for simulation in simulations: + if simulation is None or simulation["id"] in seen_ids: + continue + simulation_service.reset_simulation(country_id, simulation["id"]) + seen_ids.add(simulation["id"]) + reset_simulation_ids.append(simulation["id"]) + + return reset_simulation_ids + + +def _delete_economy_cache_for_legacy_report_path( + country_id: str, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, +) -> int | None: + """ + Delete reform_impact rows using the current legacy app path assumptions: + dataset is always "default", options_hash is always "[]", and the report + year maps directly to the economy time period. This is correct for the + current app-generated legacy report flow, not arbitrary historical callers. + """ + return reform_impacts_service.delete_reform_impacts( + country_id=country_id, + policy_id=( + simulation_2["policy_id"] + if simulation_2 is not None + else simulation_1["policy_id"] + ), + baseline_policy_id=simulation_1["policy_id"], + region=simulation_1["population_id"], + dataset="default", + time_period=report_output["year"], + options_hash="[]", + api_version=get_economy_impact_cache_version(country_id), + ) @report_output_bp.route("//report", methods=["POST"]) @@ -197,3 +285,52 @@ def update_report_output(country_id: str) -> Response: except Exception as e: print(f"Error updating report output: {str(e)}") raise BadRequest(f"Failed to update report output: {str(e)}") + + +@report_output_bp.route("//report//rerun", methods=["POST"]) +@validate_country +def rerun_report_output(country_id: str, report_id: int) -> Response: + """ + Reset a legacy report output so the current app can recompute it. + + For economy reports this also purges reform_impact rows using the current + app-path assumptions about dataset/options provenance. + """ + print(f"Rerunning report output {report_id} for country {country_id}") + + report_output, simulation_1, simulation_2 = _load_report_and_linked_simulations( + country_id=country_id, + report_id=report_id, + ) + + report_output_service.reset_report_output(country_id, report_id) + reset_simulation_ids = _reset_linked_simulations( + country_id, simulation_1, simulation_2 + ) + + economy_cache_rows_deleted = 0 + if simulation_1["population_type"] == "geography": + deleted_rows = _delete_economy_cache_for_legacy_report_path( + country_id=country_id, + report_output=report_output, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + economy_cache_rows_deleted = deleted_rows or 0 + + response_body = dict( + status="ok", + message="Report rerun reset successfully", + result=dict( + report_id=report_id, + report_type=simulation_1["population_type"], + simulation_ids=reset_simulation_ids, + economy_cache_rows_deleted=economy_cache_rows_deleted, + ), + ) + + return Response( + json.dumps(response_body), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 0f41352f3..a1d1b522a 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -136,29 +136,60 @@ def delete_reform_impact( dataset, time_period, options_hash, + ): + return self.delete_reform_impacts( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + statuses=("computing",), + ) + + def delete_reform_impacts( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version=None, + statuses=None, ): try: - query = ( + query = [ "DELETE FROM reform_impact WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND " "region = ? AND time_period = ? AND options_hash = ? AND " - "dataset = ? AND status = 'computing'" - ) + "dataset = ?" + ] + params = [ + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + dataset, + ] - local_database.query( - query, - ( - country_id, - policy_id, - baseline_policy_id, - region, - time_period, - options_hash, - dataset, - ), - ) + if api_version is not None: + query.append(" AND api_version = ?") + params.append(api_version) + + if statuses: + placeholders = ", ".join(["?"] * len(statuses)) + query.append(f" AND status IN ({placeholders})") + params.extend(statuses) + + result = local_database.query("".join(query), tuple(params)) + return getattr(result, "rowcount", None) except Exception as e: - print(f"Error deleting reform impact: {str(e)}") + print(f"Error deleting reform impacts: {str(e)}") raise e def set_error_reform_impact( diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 3200ec6e8..7d22d9a73 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -262,3 +262,34 @@ def update_report_output( except Exception as e: print(f"Error updating report output #{report_id}. Details: {str(e)}") raise e + + def reset_report_output(self, country_id: str, report_id: int) -> bool: + """ + Reset a stored report output row back to a pending state. + + This is intentionally separate from update_report_output so rerun paths + can clear persisted output and errors without changing PATCH semantics. + """ + print(f"Resetting report output {report_id}") + + try: + requested_report = self._get_report_output_row(report_id) + if requested_report is None: + raise Exception(f"Report output #{report_id} not found") + + if requested_report["country_id"] != country_id: + raise Exception( + f"Report output #{report_id} does not belong to country {country_id}" + ) + + database.query( + "UPDATE report_outputs SET status = ?, output = NULL, error_message = NULL WHERE id = ?", + ("pending", requested_report["id"]), + ) + + print(f"Successfully reset report output #{report_id}") + return True + + except Exception as e: + print(f"Error resetting report output #{report_id}. Details: {str(e)}") + raise e diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 7b83689e5..9386e683f 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -1,4 +1,3 @@ -import json from sqlalchemy.engine.row import Row from policyengine_api.data import database @@ -193,3 +192,30 @@ def update_simulation( except Exception as e: print(f"Error updating simulation #{simulation_id}. Details: {str(e)}") raise e + + def reset_simulation(self, country_id: str, simulation_id: int) -> bool: + """ + Reset a simulation row back to a pending state and clear persisted + output and errors. + """ + print(f"Resetting simulation {simulation_id}") + api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + + try: + simulation = self.get_simulation( + country_id=country_id, simulation_id=simulation_id + ) + if simulation is None: + raise Exception(f"Simulation #{simulation_id} not found") + + database.query( + "UPDATE simulations SET status = ?, output = NULL, error_message = NULL, api_version = ? WHERE id = ?", + ("pending", api_version, simulation_id), + ) + + print(f"Successfully reset simulation #{simulation_id}") + return True + + except Exception as e: + print(f"Error resetting simulation #{simulation_id}. Details: {str(e)}") + raise e diff --git a/pyproject.toml b/pyproject.toml index d80c55cdc..31579fbd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "policyengine_uk==2.78.0", "policyengine_us==1.633.2", "policyengine_core>=3.16.6", - "policyengine>=0.7.0", + "policyengine>0.12.0,<1", "pydantic", "pymysql", "python-dotenv", diff --git a/tests/unit/routes/test_report_output_routes.py b/tests/unit/routes/test_report_output_routes.py new file mode 100644 index 000000000..727438871 --- /dev/null +++ b/tests/unit/routes/test_report_output_routes.py @@ -0,0 +1,429 @@ +import pytest +from flask import Flask + +from policyengine_api.constants import ( + get_economy_impact_cache_version, + get_report_output_cache_version, +) +from policyengine_api.routes.error_routes import error_bp +from policyengine_api.routes.report_output_routes import report_output_bp + + +@pytest.fixture +def client(): + app = Flask(__name__) + app.config["TESTING"] = True + app.register_blueprint(error_bp) + app.register_blueprint(report_output_bp) + + with app.test_client() as test_client: + yield test_client + + +def insert_simulation( + test_db, + *, + country_id="us", + api_version="0.0.0", + population_id="household_1", + population_type="household", + policy_id=1, + status="complete", + output='{"result": true}', + error_message="old error", +): + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status, output, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + country_id, + api_version, + population_id, + population_type, + policy_id, + status, + output, + error_message, + ), + ) + return test_db.query( + "SELECT * FROM simulations ORDER BY id DESC LIMIT 1" + ).fetchone() + + +def insert_report_output( + test_db, + *, + country_id="us", + simulation_1_id, + simulation_2_id=None, + status="complete", + output='{"report": true}', + error_message="old error", + year="2025", +): + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, api_version, status, output, error_message, year) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + country_id, + simulation_1_id, + simulation_2_id, + get_report_output_cache_version(country_id), + status, + output, + error_message, + year, + ), + ) + return test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + +def insert_reform_impact( + test_db, + *, + baseline_policy_id, + reform_policy_id, + country_id="us", + region="us", + dataset="default", + time_period="2025", + options_json="[]", + options_hash="[]", + api_version=None, + reform_impact_json='{"impact": 1}', + status="ok", + message="Completed", + execution_id="exec-1", +): + if api_version is None: + api_version = get_economy_impact_cache_version(country_id) + + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, time_period, + options_json, options_hash, api_version, reform_impact_json, status, message, start_time, + execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, ?)""", + ( + baseline_policy_id, + reform_policy_id, + country_id, + region, + dataset, + time_period, + options_json, + options_hash, + api_version, + reform_impact_json, + status, + message, + execution_id, + ), + ) + + +def test_rerun_report_output_resets_household_report_and_simulation(client, test_db): + simulation = insert_simulation(test_db) + report_output = insert_report_output(test_db, simulation_1_id=simulation["id"]) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "household", + "simulation_ids": [simulation["id"]], + "economy_cache_rows_deleted": 0, + } + + reset_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert reset_report["status"] == "pending" + assert reset_report["output"] is None + assert reset_report["error_message"] is None + + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", (simulation["id"],) + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + + +def test_rerun_report_output_resets_household_comparison_report_and_both_simulations( + client, test_db +): + baseline_simulation = insert_simulation( + test_db, + population_id="household_baseline", + policy_id=20, + ) + reform_simulation = insert_simulation( + test_db, + population_id="household_reform", + policy_id=21, + output='{"result": "comparison"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "household", + "simulation_ids": [baseline_simulation["id"], reform_simulation["id"]], + "economy_cache_rows_deleted": 0, + } + + for simulation_id in (baseline_simulation["id"], reform_simulation["id"]): + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation_id,), + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + + +def test_rerun_report_output_resets_economy_report_and_purges_cache(client, test_db): + baseline_simulation = insert_simulation( + test_db, + population_id="state/ca", + population_type="geography", + policy_id=10, + ) + reform_simulation = insert_simulation( + test_db, + population_id="state/ca", + population_type="geography", + policy_id=11, + output='{"result": "reform"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + ) + + current_version = get_economy_impact_cache_version("us") + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + api_version=current_version, + execution_id="exec-current", + ) + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + api_version="e1stale01", + execution_id="exec-stale", + ) + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + dataset="enhanced_cps", + api_version=current_version, + execution_id="exec-other-dataset", + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "geography", + "simulation_ids": [baseline_simulation["id"], reform_simulation["id"]], + "economy_cache_rows_deleted": 1, + } + + remaining_reform_impacts = test_db.query( + "SELECT execution_id FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_reform_impacts] == [ + "exec-other-dataset", + "exec-stale", + ] + + +def test_rerun_report_output_single_simulation_economy_uses_baseline_policy_for_cache_key( + client, test_db +): + simulation = insert_simulation( + test_db, + population_id="state/ny", + population_type="geography", + policy_id=30, + ) + report_output = insert_report_output(test_db, simulation_1_id=simulation["id"]) + + current_version = get_economy_impact_cache_version("us") + insert_reform_impact( + test_db, + baseline_policy_id=30, + reform_policy_id=30, + region="state/ny", + api_version=current_version, + execution_id="exec-matching", + ) + insert_reform_impact( + test_db, + baseline_policy_id=30, + reform_policy_id=31, + region="state/ny", + api_version=current_version, + execution_id="exec-other-policy", + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "geography", + "simulation_ids": [simulation["id"]], + "economy_cache_rows_deleted": 1, + } + + remaining_reform_impacts = test_db.query( + "SELECT execution_id FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_reform_impacts] == [ + "exec-other-policy" + ] + + +def test_rerun_report_output_missing_report_returns_404(client): + response = client.post("/us/report/999/rerun") + + assert response.status_code == 404 + payload = response.get_json() + assert payload["status"] == "error" + assert payload["result"] is None + assert "Report #999 not found." in payload["message"] + + +def test_rerun_report_output_missing_linked_simulation_returns_400(client, test_db): + report_output = insert_report_output(test_db, simulation_1_id=999) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert payload["result"] is None + assert "references simulation #999" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + + +def test_rerun_report_output_missing_secondary_simulation_does_not_partially_reset( + client, test_db +): + baseline_simulation = insert_simulation( + test_db, + population_id="household_baseline", + policy_id=40, + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=999, + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert "references simulation #999" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + assert unchanged_report["error_message"] == "old error" + + unchanged_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", (baseline_simulation["id"],) + ).fetchone() + assert unchanged_simulation["status"] == "complete" + assert unchanged_simulation["output"] == '{"result": true}' + assert unchanged_simulation["error_message"] == "old error" + + +def test_rerun_report_output_mismatched_population_types_returns_controlled_error( + client, test_db +): + geography_simulation = insert_simulation( + test_db, + population_id="state/tx", + population_type="geography", + policy_id=50, + ) + household_simulation = insert_simulation( + test_db, + population_id="household_mismatch", + population_type="household", + policy_id=51, + output='{"result": "mismatch"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=geography_simulation["id"], + simulation_2_id=household_simulation["id"], + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert "mismatched population types" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + assert unchanged_report["error_message"] == "old error" + + for simulation_id, expected_output in ( + (geography_simulation["id"], '{"result": true}'), + (household_simulation["id"], '{"result": "mismatch"}'), + ): + unchanged_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation_id,), + ).fetchone() + assert unchanged_simulation["status"] == "complete" + assert unchanged_simulation["output"] == expected_output diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py new file mode 100644 index 000000000..ee12327e5 --- /dev/null +++ b/tests/unit/services/test_reform_impacts_service.py @@ -0,0 +1,141 @@ +import datetime + +from policyengine_api.constants import get_economy_impact_cache_version +from policyengine_api.services.reform_impacts_service import ReformImpactsService + + +def insert_reform_impact( + test_db, + *, + baseline_policy_id=1, + reform_policy_id=2, + country_id="us", + region="us", + dataset="default", + time_period="2025", + options_json="[]", + options_hash="[]", + api_version=None, + reform_impact_json='{"result": 1}', + status="ok", + message="Completed", + execution_id="exec-1", +): + if api_version is None: + api_version = get_economy_impact_cache_version(country_id) + + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, time_period, + options_json, options_hash, api_version, reform_impact_json, status, message, start_time, + execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + baseline_policy_id, + reform_policy_id, + country_id, + region, + dataset, + time_period, + options_json, + options_hash, + api_version, + reform_impact_json, + status, + message, + datetime.datetime(2026, 1, 1, 12, 0, 0), + execution_id, + ), + ) + + +class TestReformImpactsService: + def test_delete_reform_impacts_deletes_completed_rows_for_exact_cache_key( + self, test_db + ): + service = ReformImpactsService() + current_version = get_economy_impact_cache_version("us") + + insert_reform_impact( + test_db, + api_version=current_version, + status="ok", + execution_id="exec-ok", + ) + insert_reform_impact( + test_db, + api_version=current_version, + status="error", + execution_id="exec-error", + ) + insert_reform_impact( + test_db, + api_version=current_version, + status="computing", + execution_id="exec-computing", + ) + insert_reform_impact( + test_db, + api_version="e1stale01", + status="ok", + execution_id="exec-stale", + ) + insert_reform_impact( + test_db, + dataset="enhanced_cps", + api_version=current_version, + status="ok", + execution_id="exec-other-dataset", + ) + + deleted_rows = service.delete_reform_impacts( + country_id="us", + policy_id=2, + baseline_policy_id=1, + region="us", + dataset="default", + time_period="2025", + options_hash="[]", + api_version=current_version, + ) + + assert deleted_rows == 3 + + remaining_rows = test_db.query( + "SELECT execution_id, dataset, api_version, status FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_rows] == [ + "exec-other-dataset", + "exec-stale", + ] + + def test_delete_reform_impact_keeps_completed_rows(self, test_db): + service = ReformImpactsService() + + insert_reform_impact( + test_db, + status="ok", + execution_id="exec-ok", + ) + insert_reform_impact( + test_db, + status="computing", + execution_id="exec-computing", + ) + + deleted_rows = service.delete_reform_impact( + country_id="us", + policy_id=2, + baseline_policy_id=1, + region="us", + dataset="default", + time_period="2025", + options_hash="[]", + ) + + assert deleted_rows == 1 + + remaining_rows = test_db.query( + "SELECT execution_id, status FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert remaining_rows == [{"execution_id": "exec-ok", "status": "ok"}] diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index e3b63cbd3..24679d76f 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -4,9 +4,7 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.services.report_output_service import ReportOutputService -from tests.fixtures.services.report_output_fixtures import ( - existing_report_record, -) +pytest_plugins = ("tests.fixtures.services.report_output_fixtures",) service = ReportOutputService() @@ -563,3 +561,59 @@ def test_update_report_output_stale_id_keeps_stale_output_quarantined( assert rows[0]["api_version"] == stale_version assert rows[0]["status"] == "complete" assert rows[0]["output"] == output_json + + +class TestResetReportOutput: + def test_reset_report_output_clears_output_and_error(self, test_db): + output_json = json.dumps({"result": "complete"}) + error_message = "old error" + + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, error_message, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + 11, + None, + "complete", + output_json, + error_message, + get_report_output_cache_version("us"), + "2025", + ), + ) + + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + success = service.reset_report_output( + country_id="us", + report_id=report["id"], + ) + + assert success is True + + reset_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert reset_report["status"] == "pending" + assert reset_report["output"] is None + assert reset_report["error_message"] is None + + def test_reset_report_output_rejects_wrong_country(self, test_db): + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, api_version, year) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", 12, None, "complete", get_report_output_cache_version("us"), "2025"), + ) + + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + with pytest.raises(Exception, match="does not belong to country uk"): + service.reset_report_output(country_id="uk", report_id=report["id"]) diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index ac1fbccf6..dca341ba0 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -1,11 +1,19 @@ +import json + import pytest +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS from policyengine_api.services.simulation_service import SimulationService -from tests.fixtures.services.simulation_fixtures import ( - valid_simulation_data, - existing_simulation_record, -) +pytest_plugins = ("tests.fixtures.services.simulation_fixtures",) + +valid_simulation_data = { + "country_id": "us", + "api_version": "1.0.0", + "population_id": "household_test_123", + "population_type": "household", + "policy_id": 1, +} service = SimulationService() @@ -231,3 +239,59 @@ def test_duplicate_simulation_returns_existing(self, test_db): assert first_simulation["country_id"] == second_simulation["country_id"] assert first_simulation["population_id"] == second_simulation["population_id"] assert first_simulation["policy_id"] == second_simulation["policy_id"] + + +class TestResetSimulation: + def test_reset_simulation_clears_output_and_error(self, test_db): + output_json = json.dumps({"household": {"income": 100}}) + + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status, output, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + "oldvers1", + "household_reset", + "household", + 42, + "complete", + output_json, + "old error", + ), + ) + + simulation = test_db.query( + "SELECT * FROM simulations ORDER BY id DESC LIMIT 1" + ).fetchone() + + success = service.reset_simulation( + country_id="us", + simulation_id=simulation["id"], + ) + + assert success is True + + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation["id"],), + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + assert reset_simulation["api_version"] == COUNTRY_PACKAGE_VERSIONS["us"] + + def test_reset_simulation_requires_matching_country(self, test_db): + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", "oldvers1", "household_reset", "household", 43, "complete"), + ) + + simulation = test_db.query( + "SELECT * FROM simulations ORDER BY id DESC LIMIT 1" + ).fetchone() + + with pytest.raises(Exception, match="Simulation #.* not found"): + service.reset_simulation(country_id="uk", simulation_id=simulation["id"])