diff --git a/changelog.d/protect-ai-routes.fixed.md b/changelog.d/protect-ai-routes.fixed.md new file mode 100644 index 000000000..2df261331 --- /dev/null +++ b/changelog.d/protect-ai-routes.fixed.md @@ -0,0 +1 @@ +Protect AI and tracing routes with the shared API key. diff --git a/policyengine_api/routes/ai_prompt_routes.py b/policyengine_api/routes/ai_prompt_routes.py index a15bd16dd..bb0af70ca 100644 --- a/policyengine_api/routes/ai_prompt_routes.py +++ b/policyengine_api/routes/ai_prompt_routes.py @@ -1,6 +1,7 @@ from flask import Blueprint, Response, request from copy import deepcopy from policyengine_api.services.ai_prompt_service import AIPromptService +from policyengine_api.security import require_simulation_analysis_api_key from policyengine_api.utils.payload_validators import validate_country from policyengine_api.utils.payload_validators.ai import ( validate_sim_analysis_payload, @@ -12,11 +13,12 @@ ai_prompt_service = AIPromptService() -@validate_country @ai_prompt_bp.route( "//ai-prompts/", methods=["POST"], ) +@validate_country +@require_simulation_analysis_api_key def generate_ai_prompt(country_id, prompt_name: str) -> Response: """ Get an AI prompt with a given name, filled with the given data. diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 5157b807d..e97315968 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,6 +1,9 @@ +import json + from flask import Blueprint, request, Response, stream_with_context from werkzeug.exceptions import BadRequest -from policyengine_api.utils.payload_validators import validate_country + +from policyengine_api.security import require_simulation_analysis_api_key from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, ) @@ -10,7 +13,6 @@ from policyengine_api.utils.payload_validators.ai import ( validate_sim_analysis_payload, ) -import json simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() @@ -18,6 +20,7 @@ @simulation_analysis_bp.route("//simulation-analysis", methods=["POST"]) @validate_country +@require_simulation_analysis_api_key def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 3ceb2c343..e7a8cf0f4 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,5 +1,9 @@ +import json + from flask import Blueprint, request, Response, stream_with_context from werkzeug.exceptions import BadRequest + +from policyengine_api.security import require_simulation_analysis_api_key from policyengine_api.utils.payload_validators import ( validate_country, validate_tracer_analysis_payload, @@ -7,9 +11,6 @@ from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) -import json -from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS -import re tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() @@ -17,6 +18,7 @@ @tracer_analysis_bp.route("//tracer-analysis", methods=["POST"]) @validate_country +@require_simulation_analysis_api_key def execute_tracer_analysis(country_id): payload = request.json @@ -28,8 +30,6 @@ def execute_tracer_analysis(country_id): household_id = payload.get("household_id") policy_id = payload.get("policy_id") variable = payload.get("variable") - api_version = COUNTRY_PACKAGE_VERSIONS[country_id] - if not isinstance(variable, str): raise BadRequest("variable must be a string") diff --git a/policyengine_api/security.py b/policyengine_api/security.py new file mode 100644 index 000000000..e24d134fc --- /dev/null +++ b/policyengine_api/security.py @@ -0,0 +1,24 @@ +"""Security helpers for sensitive API routes.""" + +import os +from functools import wraps + +from flask import request +from werkzeug.exceptions import Unauthorized + + +def require_simulation_analysis_api_key(view): + """Require a shared API key for simulation analysis requests.""" + + @wraps(view) + def wrapped(*args, **kwargs): + expected_key = os.getenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "").strip() + if not expected_key: + raise Unauthorized("Simulation analysis API key is not configured") + + if request.headers.get("X-PolicyEngine-Api-Key") == expected_key: + return view(*args, **kwargs) + + raise Unauthorized("API key required for simulation analysis") + + return wrapped diff --git a/tests/conftest.py b/tests/conftest.py index f604176cd..1027d4241 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import time from contextlib import contextmanager from subprocess import Popen, TimeoutExpired -import sys +import os import redis import pytest from policyengine_api.api import app @@ -33,9 +33,18 @@ def running(process_arguments, seconds_to_wait_after_launch=0): def client(): """run the app for the tests to run against""" app.config["TESTING"] = True + previous_api_key = os.environ.get("POLICYENGINE_API_AI_ANALYSIS_API_KEY") + os.environ["POLICYENGINE_API_AI_ANALYSIS_API_KEY"] = "test-ai-analysis-key" with running(["redis-server"], 3): redis_client = redis.Redis() redis_client.ping() with running([sys.executable, "policyengine_api/worker.py"], 3): with app.test_client() as test_client: + test_client.environ_base["HTTP_X_POLICYENGINE_API_KEY"] = ( + "test-ai-analysis-key" + ) yield test_client + if previous_api_key is None: + os.environ.pop("POLICYENGINE_API_AI_ANALYSIS_API_KEY", None) + else: + os.environ["POLICYENGINE_API_AI_ANALYSIS_API_KEY"] = previous_api_key diff --git a/tests/unit/routes/test_ai_route_auth.py b/tests/unit/routes/test_ai_route_auth.py new file mode 100644 index 000000000..984d6f021 --- /dev/null +++ b/tests/unit/routes/test_ai_route_auth.py @@ -0,0 +1,136 @@ +import os +from unittest.mock import patch + +import pytest + +os.environ.setdefault("FLASK_DEBUG", "1") + +from policyengine_api.api import app +from tests.fixtures.simulation_analysis_prompt_fixtures import valid_input_us + + +@pytest.fixture +def client(): + app.config["TESTING"] = True + with app.test_client() as test_client: + yield test_client + + +def test_ai_prompt_rejects_requests_without_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + response = client.post( + "/us/ai-prompts/simulation_analysis", + json=valid_input_us, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 401 + assert "API key required" in response.json["message"] + + +def test_ai_prompt_rejects_loopback_requests_without_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + response = client.post( + "/us/ai-prompts/simulation_analysis", + json=valid_input_us, + environ_base={"REMOTE_ADDR": "127.0.0.1"}, + ) + + assert response.status_code == 401 + assert "API key required" in response.json["message"] + + +def test_ai_prompt_allows_requests_with_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + with patch( + "policyengine_api.routes.ai_prompt_routes.ai_prompt_service.get_prompt", + return_value="Prompt text", + ) as mock_get_prompt: + response = client.post( + "/us/ai-prompts/simulation_analysis", + json=valid_input_us, + headers={"X-PolicyEngine-Api-Key": "secret-key"}, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 200 + assert response.json["result"] == "Prompt text" + mock_get_prompt.assert_called_once() + + +def test_tracer_analysis_rejects_requests_without_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + response = client.post( + "/us/tracer-analysis", + json={ + "household_id": 1500, + "policy_id": 2, + "variable": "disposable_income", + }, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 401 + assert "API key required" in response.json["message"] + + +def test_requests_fail_closed_when_api_key_is_not_configured(client, monkeypatch): + monkeypatch.delenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", raising=False) + + response = client.post( + "/us/tracer-analysis", + json={ + "household_id": 1500, + "policy_id": 2, + "variable": "disposable_income", + }, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 401 + assert "not configured" in response.json["message"] + + +def test_env_flag_does_not_reopen_tracer_analysis(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + monkeypatch.setenv("POLICYENGINE_API_ALLOW_UNAUTHENTICATED_AI_ANALYSIS", "true") + + response = client.post( + "/us/tracer-analysis", + json={ + "household_id": 1500, + "policy_id": 2, + "variable": "disposable_income", + }, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 401 + assert "API key required" in response.json["message"] + + +def test_tracer_analysis_allows_requests_with_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + with patch( + "policyengine_api.routes.tracer_analysis_routes.tracer_analysis_service.execute_analysis", + return_value=("Existing analysis", "static"), + ) as mock_execute_analysis: + response = client.post( + "/us/tracer-analysis", + json={ + "household_id": 1500, + "policy_id": 2, + "variable": "disposable_income", + }, + headers={"X-PolicyEngine-Api-Key": "secret-key"}, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 200 + assert response.json["result"] == "Existing analysis" + mock_execute_analysis.assert_called_once_with("us", 1500, 2, "disposable_income") diff --git a/tests/unit/routes/test_simulation_analysis_auth.py b/tests/unit/routes/test_simulation_analysis_auth.py new file mode 100644 index 000000000..5205a206d --- /dev/null +++ b/tests/unit/routes/test_simulation_analysis_auth.py @@ -0,0 +1,48 @@ +import os +from unittest.mock import patch + +import pytest + +os.environ.setdefault("FLASK_DEBUG", "1") + +from policyengine_api.api import app +from tests.to_refactor.fixtures.simulation_analysis_fixtures import test_json + + +@pytest.fixture +def client(): + app.config["TESTING"] = True + with app.test_client() as test_client: + yield test_client + + +def test_simulation_analysis_rejects_requests_without_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + response = client.post( + "/us/simulation-analysis", + json=test_json, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 401 + assert "API key required" in response.json["message"] + + +def test_simulation_analysis_allows_requests_with_api_key(client, monkeypatch): + monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key") + + with patch( + "policyengine_api.routes.simulation_analysis_routes.simulation_analysis_service.execute_analysis", + return_value=("Existing analysis", "static"), + ) as mock_execute_analysis: + response = client.post( + "/us/simulation-analysis", + json=test_json, + headers={"X-PolicyEngine-Api-Key": "secret-key"}, + environ_base={"REMOTE_ADDR": "203.0.113.10"}, + ) + + assert response.status_code == 200 + assert response.json["result"] == "Existing analysis" + mock_execute_analysis.assert_called_once()