Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,28 @@ Configuration keys:
- `"/path/to/ca-bundle.pem"` - custom CA bundle

Supporting configs:
- `access.json` – map: topicName -> array of authorized subjects (JWT `sub`). May reside locally or at S3 path referenced by `access_config`.
- `access.json` – controls which users can publish to which topics, with optional message field restrictions. Loaded from the local path or S3 URI set in `access_config`.

Users can be listed with no restrictions, or with per-field constraints that limit what message content they can produce:

```json
{
"topic.runs": {
"service-account-1": {},
"service-account-2": {
"source_app": ["my-app", "other-app"],
"environment": ["prod"],
"tenant_id": ["^tenant-prod-.*$"]
}
},
"topic.other": ["service-account-3", "service-account-4"]
}
```

> Note: An empty object (`{}`) means the user has no field restrictions. The list format (`["user1", "user2"]`) is a shorthand where all users are unrestricted. Both formats can coexist in the same file.
>
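> When field constraints are configured, every constrained field must be present in the message and its value must match at least one of the allowed values listed for that field. Each allowed value is treated as a regular expression that has to match the entire field value, so a plain string such as `prod` matches only `prod`, while `^tenant-prod-.*$` matches any value starting with `tenant-prod-`.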

- `topic_schemas/*.json` – each file contains a JSON Schema for a topic. In the current code these are explicitly loaded inside `event_gate_lambda.py`. (Future enhancement: auto-discover or index file.)

Environment variables:
Expand Down
10 changes: 9 additions & 1 deletion conf/access.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
{
"public.cps.za.runs": ["FooBarUser", "IntegrationTestUser"],
"public.cps.za.runs": {
"FooBarUser": {},
"IntegrationTestUser": {},
"RestrictedTestUser": {
"source_app": ["restricted-app"],
"environment": ["dev"],
"tenant_id": ["avms"]
}
},
"public.cps.za.dlchange": ["FooUser", "BarUser"],
"public.cps.za.test": ["TestUser"]
}
36 changes: 34 additions & 2 deletions src/handlers/handler_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import logging
import os
import re
from typing import Any

import jwt
Expand All @@ -28,7 +29,7 @@

from src.handlers.handler_token import HandlerToken
from src.utils.conf_path import CONF_DIR
from src.utils.config_loader import load_access_config
from src.utils.config_loader import TopicAccessMap, load_access_config
from src.utils.utils import build_error_response
from src.writers.writer import Writer

Expand All @@ -49,7 +50,7 @@ def __init__(
self.aws_s3 = aws_s3
self.handler_token = handler_token
self.writers = writers
self.access_config: dict[str, list[str]] = {}
self.access_config: TopicAccessMap = {}
self.topics: dict[str, dict[str, Any]] = {}

def with_load_access_config(self) -> "HandlerTopic":
Expand Down Expand Up @@ -159,6 +160,10 @@ def _post_topic_message(self, topic_name: str, topic_message: dict[str, Any], to
if topic_name not in self.access_config or user not in self.access_config[topic_name]:
return build_error_response(403, "auth", "User not authorized for topic")

allowed, perm_error = self._validate_user_permissions(topic_name, user, topic_message)
if not allowed:
return build_error_response(403, "permission", perm_error or "Permission denied")
Comment thread
tmikula-dev marked this conversation as resolved.

try:
validate(instance=topic_message, schema=self.topics[topic_name])
except ValidationError as exc:
Expand All @@ -182,3 +187,30 @@ def _post_topic_message(self, topic_name: str, topic_message: dict[str, Any], to
"headers": {"Content-Type": "application/json"},
"body": json.dumps({"success": True, "statusCode": 202}),
}

def _validate_user_permissions(
self,
topic_name: str,
user: str,
message: dict[str, Any],
) -> tuple[bool, str | None]:
"""Check message fields against the user's permission constraints.
Args:
topic_name: Target topic name.
user: Authenticated user.
message: Message payload to validate.
Returns:
Tuple of (allowed, error_message). `error_message` is `None` when allowed.
"""
user_permissions = self.access_config[topic_name][user]
if not user_permissions:
return True, None

for restricted_field, allowed_values in user_permissions.items():
message_value = message.get(restricted_field)
if message_value is None:
return False, f"Required field '{restricted_field}' missing from message"
if not any(re.fullmatch(allowed, str(message_value)) for allowed in allowed_values):
return False, f"Field '{restricted_field}' value not permitted for user '{user}'"

return True, None
2 changes: 2 additions & 0 deletions src/readers/reader_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def read_stats(
user=db_config["user"],
password=db_config["password"],
port=db_config["port"],
connect_timeout=10,
gssencmode="disable",
options="-c statement_timeout=30000 -c default_transaction_read_only=on",
) as connection:
with connection.cursor() as db_cursor:
Expand Down
48 changes: 42 additions & 6 deletions src/utils/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

logger = logging.getLogger(__name__)

# {topic: {user: {restricted_field: [allowed_values]}}}
TopicAccessMap = dict[str, dict[str, dict[str, list[str]]]]


def load_config(conf_dir: str) -> dict[str, Any]:
"""Load the main configuration from config.json.
Expand All @@ -40,32 +43,65 @@ def load_config(conf_dir: str) -> dict[str, Any]:
return config


def load_access_config(config: dict[str, Any], aws_s3: ServiceResource) -> dict[str, list[str]]:
def _normalize_access_config(access_data: dict[str, Any]) -> TopicAccessMap:
    """Normalize access config to unified internal format.

    Converts the legacy list format (`["user1", "user2"]`) to the dict
    format (`{"user1": {}, "user2": {}}`) so that all downstream code
    can rely on a single structure. Dict-format entries are validated and
    kept as they are.

    Args:
        access_data: Parsed JSON from access config file (mixed list/dict values).
    Returns:
        Normalized mapping: `{topic: {user: {restricted_field: [allowed_values]}}}`.
    Raises:
        ValueError: If a topic value is neither a list nor a dict, if a user's
            constraints are not a dict, if a field's patterns are not a list,
            or if a pattern is not a string or not a valid regular expression.
    """
    result: TopicAccessMap = {}
    for topic, value in access_data.items():
        if isinstance(value, list):
            # Legacy format: plain user list with no field restrictions
            result[topic] = {user: {} for user in value}
        elif isinstance(value, dict):
            # New format: already in the expected structure, only needs validating
            for user, constraints in value.items():
                _validate_user_constraints(topic, user, constraints)
            result[topic] = value
        else:
            raise ValueError(f"Topic '{topic}': expected list or dict, got {type(value).__name__}.")
    return result


def _validate_user_constraints(topic: str, user: str, constraints: Any) -> None:
    """Raise ValueError unless `constraints` maps field names to lists of valid regex strings."""
    import re  # function-scope import: only needed while validating the config

    if not isinstance(constraints, dict):
        raise ValueError(
            f"Topic '{topic}', user '{user}': constraints must be a dict, got {type(constraints).__name__}."
        )
    for field, patterns in constraints.items():
        if not isinstance(patterns, list):
            raise ValueError(
                f"Topic '{topic}', user '{user}', field '{field}': patterns must be a list, got {type(patterns).__name__}."
            )
        for pattern in patterns:
            # Reject bad patterns while loading the config so they cannot
            # surface later as regex errors during a permission check.
            if not isinstance(pattern, str):
                raise ValueError(
                    f"Topic '{topic}', user '{user}', field '{field}': each pattern must be a string, got {type(pattern).__name__}."
                )
            try:
                re.compile(pattern)
            except re.error as exc:
                raise ValueError(
                    f"Topic '{topic}', user '{user}', field '{field}': invalid regex pattern '{pattern}': {exc}."
                ) from exc
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def load_access_config(config: dict[str, Any], aws_s3: ServiceResource) -> TopicAccessMap:
"""Load access control configuration from S3 or a local file.
Args:
config: Main configuration dict (must contain `access_config` key).
aws_s3: Boto3 S3 resource for loading from S3 paths.
Returns:
Dictionary mapping topic names to lists of authorised users.
Normalized mapping of topic names to per-user permission constraints.
"""
access_path: str = config["access_config"]
logger.debug("Loading access configuration from %s.", access_path)

access_config: dict[str, list[str]] = {}
access_data: dict[str, Any] = {}

if access_path.startswith("s3://"):
name_parts = access_path.split("/")
bucket_name = name_parts[2]
bucket_object_key = "/".join(name_parts[3:])
access_config = json.loads(
access_data = json.loads(
aws_s3.Bucket(bucket_name).Object(bucket_object_key).get()["Body"].read().decode("utf-8")
)
else:
with open(access_path, "r", encoding="utf-8") as file:
access_config = json.load(file)
access_data = json.load(file)

logger.debug("Loaded access configuration.")
return access_config
return _normalize_access_config(access_data)


def load_topic_names(conf_dir: str) -> list[str]:
Expand Down
2 changes: 2 additions & 0 deletions src/writers/writer_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N
user=db_config["user"],
password=db_config["password"],
port=db_config["port"],
connect_timeout=10,
gssencmode="disable",
) as connection:
with connection.cursor() as cursor:
if topic_name == "public.cps.za.dlchange":
Expand Down
66 changes: 64 additions & 2 deletions tests/unit/handlers/test_handler_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from unittest.mock import patch, mock_open, MagicMock

import jwt
import pytest

from src.handlers.handler_topic import HandlerTopic

Expand All @@ -40,7 +41,7 @@ def test_load_access_config_from_local_file():
result = handler.with_load_access_config()

assert result is handler
assert handler.access_config == access_data
assert {"TestUser": {}} == handler.access_config["public.cps.za.test"]


def test_load_access_config_from_s3():
Expand All @@ -63,7 +64,7 @@ def test_load_access_config_from_s3():
result = handler.with_load_access_config()

assert result is handler
assert handler.access_config == access_data
assert {"TestUser": {}} == handler.access_config["public.cps.za.test"]
mock_aws_s3.Bucket.assert_called_once_with("my-bucket")
mock_aws_s3.Bucket.return_value.Object.assert_called_once_with("path/to/access.json")

Expand Down Expand Up @@ -259,3 +260,64 @@ def test_token_extraction_lowercase_bearer_header(event_gate_module, make_event,
)
resp = event_gate_module.lambda_handler(event)
assert 202 == resp["statusCode"]


## _validate_user_permissions()
@pytest.mark.parametrize(
    "user_perms,payload_updates",
    [
        pytest.param({}, {}, id="no-restrictions"),
        pytest.param({"source_app": ["app"]}, {}, id="exact-match"),
        pytest.param({"tenant_id": ["avms|avm"]}, {"tenant_id": "avms"}, id="regex-match"),
        pytest.param({"source_app": ["app"], "environment": ["dev"]}, {}, id="multiple-fields"),
    ],
)
def test_post_permission_allowed(event_gate_module, make_event, valid_payload, user_perms, payload_updates):
    """A user whose field constraints all match the payload gets a 202 response."""
    topic = "public.cps.za.test"
    handler_topic = event_gate_module.handler_topic
    valid_payload.update(payload_updates)

    with patch.object(event_gate_module.handler_token, "decode_jwt", return_value={"sub": "TestUser"}):
        handler_topic.access_config[topic] = {"TestUser": user_perms}
        # Stub every writer so a permitted post is reported as written.
        for writer in handler_topic.writers.values():
            writer.write = MagicMock(return_value=(True, None))

        response = event_gate_module.lambda_handler(
            make_event(
                "/topics/{topic_name}",
                method="POST",
                topic=topic,
                body=valid_payload,
                headers={"Authorization": "Bearer token"},
            )
        )
        assert 202 == response["statusCode"]


@pytest.mark.parametrize(
    "user_perms,payload_updates,expected_fragment",
    [
        pytest.param({"environment": ["prod"]}, {}, "environment", id="value-mismatch"),
        pytest.param({"nonexistent_field": ["val"]}, {}, "nonexistent_field", id="missing-field"),
        pytest.param({"tenant_id": ["avms|avm"]}, {"tenant_id": "xxxx"}, "tenant_id", id="regex-no-match"),
        pytest.param(
            {"source_app": ["other"], "environment": ["prod"]},
            {},
            "source_app",
            id="first-constraint-fails",
        ),
    ],
)
def test_post_permission_denied(
    event_gate_module, make_event, valid_payload, user_perms, payload_updates, expected_fragment
):
    """A user whose field constraints do not match the payload gets a 403 permission error."""
    topic = "public.cps.za.test"
    valid_payload.update(payload_updates)

    with patch.object(event_gate_module.handler_token, "decode_jwt", return_value={"sub": "TestUser"}):
        event_gate_module.handler_topic.access_config[topic] = {"TestUser": user_perms}

        response = event_gate_module.lambda_handler(
            make_event(
                "/topics/{topic_name}",
                method="POST",
                topic=topic,
                body=valid_payload,
                headers={"Authorization": "Bearer token"},
            )
        )
        assert 403 == response["statusCode"]

        # The error must be a permission error that names the offending field.
        first_error = json.loads(response["body"])["errors"][0]
        assert "permission" == first_error["type"]
        assert expected_fragment in first_error["message"]
9 changes: 6 additions & 3 deletions tests/unit/test_conf_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ def test_config_files_have_required_keys(config_files, key):
def test_access_json_structure():
path = os.path.join(CONF_DIR, "access.json")
data = load_json(path)
assert isinstance(data, dict), "access.json must contain an object mapping topic -> list[user]"
assert isinstance(data, dict), "access.json must contain an object mapping topic -> users"
for topic, users in data.items():
assert isinstance(topic, str)
assert isinstance(users, list), f"Topic {topic} value must be a list of users"
assert all(isinstance(u, str) for u in users), f"All users for topic {topic} must be strings"
assert isinstance(users, (list, dict)), f"Topic {topic} value must be a list or dict of users"
if isinstance(users, list):
assert all(isinstance(u, str) for u in users), f"All users for topic {topic} must be strings"
else:
assert all(isinstance(u, str) for u in users), f"All user keys for topic {topic} must be strings"


@pytest.mark.parametrize("topic_file", glob(os.path.join(CONF_DIR, "topic_schemas", "*.json")))
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_event_gate_lambda_local_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ def Bucket(self, name): # noqa: D401
egl_reloaded = importlib.reload(egl)

assert not egl_reloaded.config["access_config"].startswith("s3://") # type: ignore[attr-defined]
assert egl_reloaded.handler_topic.access_config["public.cps.za.test"] == ["User"] # type: ignore[attr-defined]
assert {"User": {}} == egl_reloaded.handler_topic.access_config["public.cps.za.test"] # type: ignore[attr-defined]
Loading