-
Notifications
You must be signed in to change notification settings - Fork 0
PostgreSQL connection pooling #131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,7 @@ | |
| try: | ||
| import psycopg2 | ||
| from psycopg2 import Error as PsycopgError | ||
| from psycopg2 import OperationalError | ||
| from psycopg2 import sql as psycopg2_sql | ||
| except ImportError: | ||
| psycopg2 = None # type: ignore | ||
|
|
@@ -42,6 +43,9 @@ | |
| class PsycopgError(Exception): # type: ignore | ||
| """Shim psycopg2 error base when psycopg2 is not installed.""" | ||
|
|
||
| class OperationalError(PsycopgError): # type: ignore | ||
| """Shim psycopg2 OperationalError when psycopg2 is not installed.""" | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -70,6 +74,7 @@ def __init__(self) -> None: | |
| self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") | ||
| self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") | ||
| self._db_config: dict[str, Any] | None = None | ||
| self._connection: Any | None = None | ||
| logger.debug("Initialized PostgreSQL reader.") | ||
|
|
||
| def _load_db_config(self) -> dict[str, Any]: | ||
|
|
@@ -81,6 +86,22 @@ def _load_db_config(self) -> dict[str, Any]: | |
| raise RuntimeError("Failed to load database configuration.") | ||
| return config | ||
|
|
||
| def _get_connection(self) -> Any: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is actually a pretty common pattern, for example:
class DB:
    @cached_property
    def conn(self):
        return create_connection()
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It should also be lazy by default, so that the connection is established on the first call, not eagerly. |
||
| """Return a cached database connection, creating one if needed.""" | ||
| if self._connection is not None and not self._connection.closed: | ||
| return self._connection | ||
| db_config = self._load_db_config() | ||
| self._connection = psycopg2.connect( # type: ignore[attr-defined] | ||
| database=db_config["database"], | ||
| host=db_config["host"], | ||
| user=db_config["user"], | ||
| password=db_config["password"], | ||
| port=db_config["port"], | ||
| options="-c statement_timeout=30000 -c default_transaction_read_only=on", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider extracting this 30 s timeout into a named constant. |
||
| ) | ||
| logger.debug("New PostgreSQL reader connection established.") | ||
| return self._connection | ||
|
|
||
| def read_stats( | ||
| self, | ||
| timestamp_start: int | None = None, | ||
|
|
@@ -124,20 +145,23 @@ def read_stats( | |
| params.append(limit + 1) | ||
|
|
||
| try: | ||
| with psycopg2.connect( # type: ignore[attr-defined] | ||
| database=db_config["database"], | ||
| host=db_config["host"], | ||
| user=db_config["user"], | ||
| password=db_config["password"], | ||
| port=db_config["port"], | ||
| options="-c statement_timeout=30000 -c default_transaction_read_only=on", | ||
| ) as connection: | ||
| with connection.cursor() as db_cursor: | ||
| db_cursor.execute(query, params) | ||
| col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] | ||
| raw_rows = db_cursor.fetchall() | ||
| for attempt in range(2): | ||
| try: | ||
| connection = self._get_connection() | ||
| with connection.cursor() as db_cursor: | ||
| db_cursor.execute(query, params) | ||
| col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] | ||
| raw_rows = db_cursor.fetchall() | ||
| connection.rollback() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A rollback on a read? What am I missing here? |
||
| break | ||
| except OperationalError as exc: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This method does way too much. It loads the DB config, validates it, performs retries, manipulates the cursor, and unpacks and post-processes the values. Please split it; it is hard to read, hard to test, and hard to extend. My advice: try it yourself first. Think about it and reason it through in your head: what are the responsibilities of the individual logical blocks of code? You will need more methods; how many is something the previous question should tell you (if it is too many, say 3-4 or more, and it smells like a separate, moderately sized or larger piece, perhaps used in at least two places rather than just one, put it into a class). Then maybe consider asking Copilot with Opus 4.6/4.7 to refactor it and see what happens. Then learn from it, and consider asking repeatedly, but now with more specific and exact instructions about how it should refactor the code - you provide the 'tutorial' for it. And so on. This is what I do with architectural questions, refactoring tasks, and even when coding projects completely from scratch. I think this is one of the skills to have. AI will do what you tell it to do. |
||
| self._connection = None | ||
| if attempt > 0: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What am I missing here - is the retry even working? Would it not fail after the first attempt? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Or is it that the RuntimeError is non-retriable? |
||
| raise RuntimeError(f"Database connection failed after retry: {exc}") from exc | ||
| logger.warning("PostgreSQL connection lost, reconnecting.") | ||
| except PsycopgError as exc: | ||
| raise RuntimeError(f"Database query failed: {exc}") from exc | ||
| self._connection = None | ||
| raise RuntimeError(f"Database query error: {exc}") from exc | ||
|
|
||
| rows = [dict(zip(col_names, row, strict=True)) for row in raw_rows] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| try: | ||
| import psycopg2 | ||
| from psycopg2 import Error as PsycopgError | ||
| from psycopg2 import OperationalError | ||
| from psycopg2 import sql as psycopg2_sql | ||
| except ImportError: | ||
| psycopg2 = None # type: ignore | ||
|
|
@@ -39,6 +40,9 @@ | |
| class PsycopgError(Exception): # type: ignore | ||
| """Shim psycopg2 error base when psycopg2 is not installed.""" | ||
|
|
||
| class OperationalError(PsycopgError): # type: ignore | ||
| """Shim psycopg2 OperationalError when psycopg2 is not installed.""" | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -53,6 +57,7 @@ def __init__(self, config: dict[str, Any]) -> None: | |
| self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") | ||
| self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") | ||
| self._db_config: dict[str, Any | None] | None = None | ||
| self._connection: Any | None = None | ||
| logger.debug("Initialized PostgreSQL writer.") | ||
|
|
||
| def _load_db_config(self) -> dict[str, Any]: | ||
|
|
@@ -61,6 +66,21 @@ def _load_db_config(self) -> dict[str, Any]: | |
| self._db_config = load_postgres_config(self._secret_name, self._secret_region) | ||
| return self._db_config # type: ignore[return-value] | ||
|
|
||
| def _get_connection(self) -> Any: | ||
| """Return a cached database connection, creating one if needed.""" | ||
| if self._connection is not None and not self._connection.closed: | ||
| return self._connection | ||
| db_config = self._load_db_config() | ||
| self._connection = psycopg2.connect( # type: ignore[attr-defined] | ||
| database=db_config["database"], | ||
| host=db_config["host"], | ||
| user=db_config["user"], | ||
| password=db_config["password"], | ||
| port=db_config["port"], | ||
| ) | ||
| logger.debug("New PostgreSQL writer connection established.") | ||
| return self._connection | ||
|
|
||
| def _postgres_edla_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: | ||
| """Insert a dlchange style event row. | ||
| Args: | ||
|
|
@@ -278,23 +298,25 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N | |
|
|
||
| table_info = TOPIC_TABLE_MAP[topic_name] | ||
|
|
||
| with psycopg2.connect( # type: ignore[attr-defined] | ||
| database=db_config["database"], | ||
| host=db_config["host"], | ||
| user=db_config["user"], | ||
| password=db_config["password"], | ||
| port=db_config["port"], | ||
| ) as connection: | ||
| with connection.cursor() as cursor: | ||
| if topic_name == "public.cps.za.dlchange": | ||
| self._postgres_edla_write(cursor, table_info["main"], message) | ||
| elif topic_name == "public.cps.za.runs": | ||
| self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) | ||
| elif topic_name == "public.cps.za.test": | ||
| self._postgres_test_write(cursor, table_info["main"], message) | ||
|
|
||
| connection.commit() | ||
| for attempt in range(2): | ||
| try: | ||
| connection = self._get_connection() | ||
| with connection.cursor() as cursor: | ||
| if topic_name == "public.cps.za.dlchange": | ||
| self._postgres_edla_write(cursor, table_info["main"], message) | ||
| elif topic_name == "public.cps.za.runs": | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider extracting these topic names into a common constant - maybe a frozen data class or similar. |
||
| self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) | ||
| elif topic_name == "public.cps.za.test": | ||
| self._postgres_test_write(cursor, table_info["main"], message) | ||
| connection.commit() | ||
| break | ||
| except OperationalError: | ||
| self._connection = None | ||
| if attempt > 0: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider improving the logs, similar to how it is done in the Reader. In fact, maybe you can extract some of it into a common class that handles connections and retries. Just a quick idea; I would need to think about it more, so just consider it. I know that the connection settings for reading and writing are slightly different (but that could be parametrized or so). |
||
| raise | ||
| logger.warning("PostgreSQL connection lost, reconnecting.") | ||
| except (RuntimeError, PsycopgError, BotoCoreError, ClientError, ValueError, KeyError) as e: | ||
| self._connection = None | ||
| err_msg = f"The Postgres writer failed with unknown error: {str(e)}" | ||
| logger.exception(err_msg) | ||
| return False, err_msg | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| # | ||
| # Copyright 2026 ABSA Group Limited | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
|
|
||
| import time | ||
| import uuid | ||
|
|
||
| import pytest | ||
|
|
||
| from tests.integration.conftest import EventGateTestClient, EventStatsTestClient | ||
|
|
||
|
|
||
| def _make_test_event() -> dict: | ||
| """Build a minimal runs event payload.""" | ||
| now_ms = int(time.time() * 1000) | ||
| return { | ||
| "event_id": str(uuid.uuid4()), | ||
| "job_ref": f"conn-reuse-{uuid.uuid4().hex[:8]}", | ||
| "tenant_id": "CONN_REUSE_TEST", | ||
| "source_app": "integration-conn-reuse", | ||
| "source_app_version": "1.0.0", | ||
| "environment": "test", | ||
| "timestamp_start": now_ms - 60000, | ||
| "timestamp_end": now_ms, | ||
| "jobs": [ | ||
| { | ||
| "catalog_id": "db.schema.conn_reuse_table", | ||
| "status": "succeeded", | ||
| "timestamp_start": now_ms - 60000, | ||
| "timestamp_end": now_ms, | ||
| } | ||
| ], | ||
| } | ||
|
|
||
|
|
||
| class TestWriterConnectionReuse: | ||
| """Verify that WriterPostgres reuses connections across invocations.""" | ||
|
|
||
| @pytest.fixture(scope="class", autouse=True) | ||
| def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: | ||
| """Post events so the writer connection is established.""" | ||
| for _ in range(2): | ||
| event = _make_test_event() | ||
| response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) | ||
| assert 202 == response["statusCode"] | ||
|
|
||
| def test_writer_reuses_connection_across_writes( | ||
| self, seed_events: None, eventgate_client: EventGateTestClient, valid_token: str | ||
| ) -> None: | ||
| """Test that subsequent writes reuse the same cached connection.""" | ||
| from src.event_gate_lambda import writers | ||
|
|
||
| writer = writers["postgres"] | ||
| conn_before = writer._connection | ||
| assert conn_before is not None | ||
| assert 0 == conn_before.closed | ||
|
|
||
| event = _make_test_event() | ||
| response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) | ||
| assert 202 == response["statusCode"] | ||
| assert conn_before is writer._connection | ||
|
|
||
|
|
||
| class TestReaderConnectionReuse: | ||
| """Verify that ReaderPostgres reuses connections across invocations.""" | ||
|
|
||
| @pytest.fixture(scope="class", autouse=True) | ||
| def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: | ||
| """Seed events so stats queries return data.""" | ||
| for _ in range(2): | ||
| event = _make_test_event() | ||
| response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) | ||
| assert 202 == response["statusCode"] | ||
|
|
||
| def test_reader_reuses_connection_across_reads(self, seed_events: None, stats_client: EventStatsTestClient) -> None: | ||
| """Test that successive queries reuse the same cached connection.""" | ||
| from src.event_stats_lambda import reader_postgres | ||
|
|
||
| stats_client.post_stats("public.cps.za.runs", {}) | ||
|
|
||
| conn_after_first = reader_postgres._connection | ||
| assert conn_after_first is not None | ||
| assert 0 == conn_after_first.closed | ||
|
|
||
| stats_client.post_stats("public.cps.za.runs", {}) | ||
| assert conn_after_first is reader_postgres._connection |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we used mypy's strict mode, I am not sure this would be allowed without a warning. It's nice that you use types, but a type that is just an optional Any does not say much.