diff --git a/README.md b/README.md index 3132429..e9bb7c9 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,28 @@ Configuration keys: - `"/path/to/ca-bundle.pem"` - custom CA bundle Supporting configs: -- `access.json` – map: topicName -> array of authorized subjects (JWT `sub`). May reside locally or at S3 path referenced by `access_config`. +- `access.json` – controls which users can publish to which topics, with optional message field restrictions. Loaded from the local path or S3 URI set in `access_config`. + + Users can be listed with no restrictions, or with per-field constraints that limit what message content they can produce: + + ```json + { + "topic.runs": { + "service-account-1": {}, + "service-account-2": { + "source_app": ["my-app", "other-app"], + "environment": ["prod"], + "tenant_id": ["^tenant-prod-.*$"] + } + }, + "topic.other": ["service-account-3", "service-account-4"] + } + ``` + + > Note: An empty object (`{}`) means the user has no field restrictions. The list format (`["user1", "user2"]`) is a shorthand where all users are unrestricted. Both formats can coexist in the same file. + > + > When field constraints are configured, all fields must match. Values support regex patterns. + - `topic_schemas/*.json` – each file contains a JSON Schema for a topic. In the current code these are explicitly loaded inside `event_gate_lambda.py`. (Future enhancement: auto-discover or index file.) Environment variables: diff --git a/conf/access.json b/conf/access.json index 827bc3d..bb21ccc 100644 --- a/conf/access.json +++ b/conf/access.json @@ -1,5 +1,13 @@ { - "public.cps.za.runs": ["FooBarUser", "IntegrationTestUser"], + "public.cps.za.runs": { + "FooBarUser": {}, + "IntegrationTestUser": {}, + "RestrictedTestUser": { + "source_app": ["restricted-app"], + "environment": ["dev"], + "tenant_id": ["avms"] + } + }, "public.cps.za.dlchange": ["FooUser", "BarUser"], "public.cps.za.test": ["TestUser"] } diff --git a/src/handlers/handler_topic.py b/src/handlers/handler_topic.py index a0e35f0..d991ef5 100644 --- a/src/handlers/handler_topic.py +++ b/src/handlers/handler_topic.py @@ -19,6 +19,7 @@ import json import logging import os +import re from typing import Any import jwt @@ -28,7 +29,7 @@ from src.handlers.handler_token import HandlerToken from src.utils.conf_path import CONF_DIR -from src.utils.config_loader import load_access_config +from src.utils.config_loader import TopicAccessMap, load_access_config from src.utils.utils import build_error_response from src.writers.writer import Writer @@ -49,7 +50,7 @@ def __init__( self.aws_s3 = aws_s3 self.handler_token = handler_token self.writers = writers - self.access_config: dict[str, list[str]] = {} + self.access_config: TopicAccessMap = {} self.topics: dict[str, dict[str, Any]] = {} def with_load_access_config(self) -> "HandlerTopic": @@ -159,6 +160,10 @@ def _post_topic_message(self, topic_name: str, topic_message: dict[str, Any], to if topic_name not in self.access_config or user not in self.access_config[topic_name]: return build_error_response(403, "auth", "User not authorized for topic") + allowed, perm_error = self._validate_user_permissions(topic_name, user, topic_message) + if not allowed: + return build_error_response(403, "permission", perm_error or "Permission denied") + try: validate(instance=topic_message, schema=self.topics[topic_name]) except ValidationError as exc: @@ -182,3 +187,30 @@ def _post_topic_message(self, topic_name: str, topic_message: dict[str, Any], to "headers": {"Content-Type": "application/json"}, "body": json.dumps({"success": True, "statusCode": 202}), } + + def _validate_user_permissions( + self, + topic_name: str, + user: str, + message: dict[str, Any], + ) -> tuple[bool, str | None]: + """Check message fields against the user's permission constraints. + Args: + topic_name: Target topic name. + user: Authenticated user. + message: Message payload to validate. + Returns: + Tuple of (allowed, error_message). `error_message` is `None` when allowed. + """ + user_permissions = self.access_config[topic_name][user] + if not user_permissions: + return True, None + + for restricted_field, allowed_values in user_permissions.items(): + message_value = message.get(restricted_field) + if message_value is None: + return False, f"Required field '{restricted_field}' missing from message" + if not any(re.fullmatch(allowed, str(message_value)) for allowed in allowed_values): + return False, f"Field '{restricted_field}' value not permitted for user '{user}'" + + return True, None diff --git a/src/readers/reader_postgres.py b/src/readers/reader_postgres.py index 7c5de0c..e20ed16 100644 --- a/src/readers/reader_postgres.py +++ b/src/readers/reader_postgres.py @@ -130,6 +130,8 @@ def read_stats( user=db_config["user"], password=db_config["password"], port=db_config["port"], + connect_timeout=10, + gssencmode="disable", options="-c statement_timeout=30000 -c default_transaction_read_only=on", ) as connection: with connection.cursor() as db_cursor: diff --git a/src/utils/config_loader.py b/src/utils/config_loader.py index a233b4a..a780808 100644 --- a/src/utils/config_loader.py +++ b/src/utils/config_loader.py @@ -25,6 +25,9 @@ logger = logging.getLogger(__name__) +# {topic: {user: {restricted_field: [allowed_values]}}} +TopicAccessMap = dict[str, dict[str, dict[str, list[str]]]] + def load_config(conf_dir: str) -> dict[str, Any]: """Load the main configuration from config.json. @@ -40,32 +43,65 @@ def load_config(conf_dir: str) -> dict[str, Any]: return config -def load_access_config(config: dict[str, Any], aws_s3: ServiceResource) -> dict[str, list[str]]: +def _normalize_access_config(access_data: dict[str, Any]) -> TopicAccessMap: + """Normalize access config to unified internal format. + Converts the legacy list format (`["user1", "user2"]`) to the dict + format (`{"user1": {}, "user2": {}}`) so that all downstream code + can rely on a single structure. + Args: + access_data: Parsed JSON from access config file (mixed list/dict values). + Returns: + Normalized mapping: `{topic: {user: {restricted_field: [allowed_values]}}}` . + """ + result: TopicAccessMap = {} + for topic, value in access_data.items(): + if isinstance(value, list): + # Legacy format: plain user list with no field restrictions + result[topic] = {user: {} for user in value} + elif isinstance(value, dict): + # New format: per-user field constraints already in the expected structure + for user, constraints in value.items(): + if not isinstance(constraints, dict): + raise ValueError( + f"Topic '{topic}', user '{user}': constraints must be a dict, got {type(constraints).__name__}." + ) + for field, patterns in constraints.items(): + if not isinstance(patterns, list): + raise ValueError( + f"Topic '{topic}', user '{user}', field '{field}': patterns must be a list, got {type(patterns).__name__}." + ) + result[topic] = value + else: + raise ValueError(f"Topic '{topic}': expected list or dict, got {type(value).__name__}.") + return result + + +def load_access_config(config: dict[str, Any], aws_s3: ServiceResource) -> TopicAccessMap: """Load access control configuration from S3 or a local file. Args: config: Main configuration dict (must contain `access_config` key). aws_s3: Boto3 S3 resource for loading from S3 paths. Returns: - Dictionary mapping topic names to lists of authorised users. + Normalized mapping of topic names to per-user permission constraints. """ access_path: str = config["access_config"] logger.debug("Loading access configuration from %s.", access_path) - access_config: dict[str, list[str]] = {} + access_data: dict[str, Any] = {} if access_path.startswith("s3://"): name_parts = access_path.split("/") bucket_name = name_parts[2] bucket_object_key = "/".join(name_parts[3:]) - access_config = json.loads( + access_data = json.loads( aws_s3.Bucket(bucket_name).Object(bucket_object_key).get()["Body"].read().decode("utf-8") ) else: with open(access_path, "r", encoding="utf-8") as file: - access_config = json.load(file) + access_data = json.load(file) logger.debug("Loaded access configuration.") - return access_config + return _normalize_access_config(access_data) def load_topic_names(conf_dir: str) -> list[str]: diff --git a/src/writers/writer_postgres.py b/src/writers/writer_postgres.py index c0533b1..5e5a9f7 100644 --- a/src/writers/writer_postgres.py +++ b/src/writers/writer_postgres.py @@ -284,6 +284,8 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N user=db_config["user"], password=db_config["password"], port=db_config["port"], + connect_timeout=10, + gssencmode="disable", ) as connection: with connection.cursor() as cursor: if topic_name == "public.cps.za.dlchange": diff --git a/tests/unit/handlers/test_handler_topic.py b/tests/unit/handlers/test_handler_topic.py index 7ab5394..cb8b107 100644 --- a/tests/unit/handlers/test_handler_topic.py +++ b/tests/unit/handlers/test_handler_topic.py @@ -18,6 +18,7 @@ from unittest.mock import patch, mock_open, MagicMock import jwt +import pytest from src.handlers.handler_topic import HandlerTopic @@ -40,7 +41,7 @@ def test_load_access_config_from_local_file(): result = handler.with_load_access_config() assert result is handler - assert handler.access_config == access_data + assert {"TestUser": {}} == handler.access_config["public.cps.za.test"] def test_load_access_config_from_s3(): @@ -63,7 +64,7 @@ def test_load_access_config_from_s3(): result = handler.with_load_access_config() assert result is handler - assert handler.access_config == access_data + assert {"TestUser": {}} == handler.access_config["public.cps.za.test"] mock_aws_s3.Bucket.assert_called_once_with("my-bucket") mock_aws_s3.Bucket.return_value.Object.assert_called_once_with("path/to/access.json") @@ -259,3 +260,64 @@ def test_token_extraction_lowercase_bearer_header(event_gate_module, make_event, ) resp = event_gate_module.lambda_handler(event) assert 202 == resp["statusCode"] + + +## _validate_user_permissions() +@pytest.mark.parametrize( + "user_perms,payload_updates", + [ + ({}, {}), + ({"source_app": ["app"]}, {}), + ({"tenant_id": ["avms|avm"]}, {"tenant_id": "avms"}), + ({"source_app": ["app"], "environment": ["dev"]}, {}), + ], + ids=["no-restrictions", "exact-match", "regex-match", "multiple-fields"], +) +def test_post_permission_allowed(event_gate_module, make_event, valid_payload, user_perms, payload_updates): + """User with matching permissions can post successfully.""" + valid_payload.update(payload_updates) + with patch.object(event_gate_module.handler_token, "decode_jwt", return_value={"sub": "TestUser"}): + event_gate_module.handler_topic.access_config["public.cps.za.test"] = {"TestUser": user_perms} + for writer in event_gate_module.handler_topic.writers.values(): + writer.write = MagicMock(return_value=(True, None)) + + event = make_event( + "/topics/{topic_name}", + method="POST", + topic="public.cps.za.test", + body=valid_payload, + headers={"Authorization": "Bearer token"}, + ) + resp = event_gate_module.lambda_handler(event) + assert 202 == resp["statusCode"] + + +@pytest.mark.parametrize( + "user_perms,payload_updates,expected_fragment", + [ + ({"environment": ["prod"]}, {}, "environment"), + ({"nonexistent_field": ["val"]}, {}, "nonexistent_field"), + ({"tenant_id": ["avms|avm"]}, {"tenant_id": "xxxx"}, "tenant_id"), + ({"source_app": ["other"], "environment": ["prod"]}, {}, "source_app"), + ], + ids=["value-mismatch", "missing-field", "regex-no-match", "first-constraint-fails"], +) +def test_post_permission_denied( + event_gate_module, make_event, valid_payload, user_perms, payload_updates, expected_fragment +): + """User with non-matching permissions gets 403.""" + valid_payload.update(payload_updates) + with patch.object(event_gate_module.handler_token, "decode_jwt", return_value={"sub": "TestUser"}): + event_gate_module.handler_topic.access_config["public.cps.za.test"] = {"TestUser": user_perms} + event = make_event( + "/topics/{topic_name}", + method="POST", + topic="public.cps.za.test", + body=valid_payload, + headers={"Authorization": "Bearer token"}, + ) + resp = event_gate_module.lambda_handler(event) + assert 403 == resp["statusCode"] + body = json.loads(resp["body"]) + assert "permission" == body["errors"][0]["type"] + assert expected_fragment in body["errors"][0]["message"] diff --git a/tests/unit/test_conf_validation.py b/tests/unit/test_conf_validation.py index 232e968..46078fa 100644 --- a/tests/unit/test_conf_validation.py +++ b/tests/unit/test_conf_validation.py @@ -52,11 +52,14 @@ def test_config_files_have_required_keys(config_files, key): def test_access_json_structure(): path = os.path.join(CONF_DIR, "access.json") data = load_json(path) - assert isinstance(data, dict), "access.json must contain an object mapping topic -> list[user]" + assert isinstance(data, dict), "access.json must contain an object mapping topic -> users" for topic, users in data.items(): assert isinstance(topic, str) - assert isinstance(users, list), f"Topic {topic} value must be a list of users" - assert all(isinstance(u, str) for u in users), f"All users for topic {topic} must be strings" + assert isinstance(users, (list, dict)), f"Topic {topic} value must be a list or dict of users" + if isinstance(users, list): + assert all(isinstance(u, str) for u in users), f"All users for topic {topic} must be strings" + else: + assert all(isinstance(u, str) for u in users), f"All user keys for topic {topic} must be strings" @pytest.mark.parametrize("topic_file", glob(os.path.join(CONF_DIR, "topic_schemas", "*.json"))) diff --git a/tests/unit/test_event_gate_lambda_local_access.py b/tests/unit/test_event_gate_lambda_local_access.py index cf635ee..bc8fdb7 100644 --- a/tests/unit/test_event_gate_lambda_local_access.py +++ b/tests/unit/test_event_gate_lambda_local_access.py @@ -72,4 +72,4 @@ def Bucket(self, name): # noqa: D401 egl_reloaded = importlib.reload(egl) assert not egl_reloaded.config["access_config"].startswith("s3://") # type: ignore[attr-defined] - assert egl_reloaded.handler_topic.access_config["public.cps.za.test"] == ["User"] # type: ignore[attr-defined] + assert {"User": {}} == egl_reloaded.handler_topic.access_config["public.cps.za.test"] # type: ignore[attr-defined] diff --git a/tests/unit/utils/test_config_loader.py b/tests/unit/utils/test_config_loader.py index 6d1aaab..ee134d3 100644 --- a/tests/unit/utils/test_config_loader.py +++ b/tests/unit/utils/test_config_loader.py @@ -18,12 +18,12 @@ import json import os from pathlib import Path -from typing import Any, Dict, List +from typing import Any from unittest.mock import MagicMock import pytest -from src.utils.config_loader import load_access_config, load_config, load_topic_names +from src.utils.config_loader import _normalize_access_config, load_access_config, load_config, load_topic_names @pytest.fixture @@ -74,8 +74,8 @@ def test_loads_from_local_file(self, conf_dir: str) -> None: result = load_access_config(config, aws_s3) - assert ["UserA"] == result["public.cps.za.runs"] - assert ["UserB"] == result["public.cps.za.test"] + assert {"UserA": {}} == result["public.cps.za.runs"] + assert {"UserB": {}} == result["public.cps.za.test"] aws_s3.Bucket.assert_not_called() def test_loads_from_s3(self) -> None: @@ -91,7 +91,7 @@ def test_loads_from_s3(self) -> None: result = load_access_config(config, mock_s3) - assert ["S3User"] == result["public.cps.za.runs"] + assert {"S3User": {}} == result["public.cps.za.runs"] mock_s3.Bucket.assert_called_once_with("my-bucket") mock_s3.Bucket.return_value.Object.assert_called_once_with("conf/access.json") @@ -124,3 +124,50 @@ def test_empty_schemas_dir(self, tmp_path: Path) -> None: result = load_topic_names(str(tmp_path)) assert [] == result + + +class TestNormalizeAccessConfig: + """Tests for _normalize_access_config().""" + + def test_normalize_list_format(self) -> None: + """Test that list format is converted to dict with empty permissions.""" + raw = {"topic": ["user1", "user2"]} + + result = _normalize_access_config(raw) + + assert {"user1": {}, "user2": {}} == result["topic"] + + def test_normalize_dict_format(self) -> None: + """Test that dict format with permissions is preserved as-is.""" + raw = {"topic": {"user1": {"source_app": ["app1"]}, "user2": {}}} + + result = _normalize_access_config(raw) + + assert {"source_app": ["app1"]} == result["topic"]["user1"] + assert {} == result["topic"]["user2"] + + def test_normalize_mixed_format(self) -> None: + """Test that mixed list and dict formats are handled correctly.""" + raw = { + "list.topic": ["userA"], + "dict.topic": {"userB": {"environment": ["prod"]}}, + } + + result = _normalize_access_config(raw) + + assert {"userA": {}} == result["list.topic"] + assert {"environment": ["prod"]} == result["dict.topic"]["userB"] + + @pytest.mark.parametrize( + "raw,error_fragment", + [ + ({"t": "bad"}, "expected list or dict"), + ({"t": {"u": "not-a-dict"}}, "constraints must be a dict"), + ({"t": {"u": {"field": "not-a-list"}}}, "patterns must be a list"), + ], + ids=["invalid-topic-value", "invalid-constraints", "invalid-patterns"], + ) + def test_normalize_rejects_malformed_config(self, raw: dict, error_fragment: str) -> None: + """Test that malformed access config raises ValueError.""" + with pytest.raises(ValueError, match=error_fragment): + _normalize_access_config(raw)