From 62b8c26059f7bb12689ab8e644415df5d0990cb6 Mon Sep 17 00:00:00 2001 From: Cristopher Pinzon Date: Mon, 23 Mar 2026 16:43:55 -0500 Subject: [PATCH 1/3] Feat: Add support for SSL --- postgresql_proxy/proxy.py | 129 ++++++++++++++++++++++++++++---------- 1 file changed, 97 insertions(+), 32 deletions(-) diff --git a/postgresql_proxy/proxy.py b/postgresql_proxy/proxy.py index 1417db9..cacaa1c 100644 --- a/postgresql_proxy/proxy.py +++ b/postgresql_proxy/proxy.py @@ -41,7 +41,7 @@ class SelectorKeyProxy(selectors.SelectorKey): class Proxy(object): - def __init__(self, instance_config, plugins, debug=False): + def __init__(self, instance_config, plugins, debug=False, ssl_context=None): self.plugins = plugins self.num_clients = 0 self.instance_config = instance_config @@ -49,6 +49,8 @@ def __init__(self, instance_config, plugins, debug=False): self.selector = selectors.DefaultSelector() self.running = True self.sock = None + self.ssl_context = ssl_context + # this is used to track leftover sockets self._debug = debug if self._debug: @@ -114,41 +116,104 @@ def accept_wrapper(self, sock: socket.socket): :param sock: the client socket :return: """ - clientsocket, address = sock.accept() # Should be ready to - clientsocket.setblocking(False) - self.num_clients += 1 - sock_name = '{}_{}'.format(self.instance_config.listen.name, self.num_clients) - LOG.info("connection from %s, connection initiated %s", address, sock_name) - events = selectors.EVENT_READ - - # Context dictionary, for sharing state data, connection details, which might be useful for interceptors - context = { - 'instance_config': self.instance_config - } - - # create a Connection object, representing the relation between a proxied client to postgres - conn = connection.Connection( - clientsocket, - name=sock_name, - address=address, - events=events, - context=context - ) + try: + # Accept the raw connection + clientsocket, address = sock.accept() + + # Check if SSL is enabled for this proxy + if self.ssl_context: + # Handle SSL negotiation - must happen before setblocking(False) + clientsocket = _handle_ssl_negotiation(clientsocket, self.ssl_context) + + clientsocket.setblocking(False) + self.num_clients += 1 + sock_name = f"{self.instance_config.listen.name}_{self.num_clients}" + LOG.info( + "Connection from %s, connection initiated %s (SSL: %s)", + address, + sock_name, + ssl_context is not None, + ) + + events = selectors.EVENT_READ + context = {"instance_config": self.instance_config} + + conn = pg_connection.Connection( + clientsocket, + name=sock_name, + address=address, + events=events, + context=context, + ) + + pg_conn = self._create_pg_connection(address, context) + + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.responses is not None + ): + pg_conn.interceptor = ResponseInterceptor( + self.instance_config.intercept.responses, self.plugins, context + ) + pg_conn.redirect_conn = conn + + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.commands is not None + ): + conn.interceptor = CommandInterceptor( + self.instance_config.intercept.commands, self.plugins, context + ) + conn.redirect_conn = pg_conn + + self._register_conn(conn) + self._register_conn(pg_conn) + + except ConnectionRefusedError: + LOG.debug("Connection refused in Postgres proxy server - instance not (yet) available") + except Exception as e: + LOG.warning("Error accepting connection in Postgres proxy: %s", e) - # create the connection to Postgres - pg_conn = self._create_pg_connection(address, context) + def _handle_ssl_negotiation( + client_socket: socket.socket, ssl_context: ssl.SSLContext + ) -> socket.socket: + """ + Handle PostgreSQL SSL negotiation on an accepted socket. - if self.instance_config.intercept is not None and self.instance_config.intercept.responses is not None: - pg_conn.interceptor = ResponseInterceptor(self.instance_config.intercept.responses, self.plugins, context) - pg_conn.redirect_conn = conn + PostgreSQL SSL flow: + 1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4) + 2. Server responds 'S' (SSL supported) or 'N' (not supported) + 3. If 'S', TLS handshake follows + 4. After TLS, normal PostgreSQL protocol begins - if self.instance_config.intercept is not None and self.instance_config.intercept.commands is not None: - conn.interceptor = CommandInterceptor(self.instance_config.intercept.commands, self.plugins, context) - conn.redirect_conn = pg_conn + Returns the SSL-wrapped socket if negotiation succeeds, or the original socket. + """ + try: + # Peek at the first 8 bytes to check for SSLRequest + # Using MSG_PEEK so we don't consume the data if it's not SSLRequest + client_socket.setblocking(True) + data = client_socket.recv(8, socket.MSG_PEEK) + + if len(data) == 8: + length = int.from_bytes(data[:4], "big") + code = int.from_bytes(data[4:8], "big") + + if length == 8 and code == 80877103: # SSLRequest code + # Consume the SSLRequest + client_socket.recv(8) + # Send 'S' to indicate SSL is supported + client_socket.send(b"S") + # Wrap socket with SSL + ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) + LOG.debug("SSL handshake completed for PostgreSQL connection") + return ssl_socket + + # Not an SSLRequest, return original socket + return client_socket - # Register both connections to be watched by the selector - self._register_conn(conn) - self._register_conn(pg_conn) + except Exception as e: + LOG.debug("Error during PostgreSQL SSL negotiation: %s", e) + return client_socket def service_connection(self, key: SelectorKeyProxy, mask): """ From 22b789a51984dff94efd96e51dd56f0d50406478 Mon Sep 17 00:00:00 2001 From: Cristopher Pinzon Date: Mon, 23 Mar 2026 17:33:43 -0500 Subject: [PATCH 2/3] fixes --- postgresql_proxy/proxy.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/postgresql_proxy/proxy.py b/postgresql_proxy/proxy.py index cacaa1c..fad632d 100644 --- a/postgresql_proxy/proxy.py +++ b/postgresql_proxy/proxy.py @@ -27,6 +27,7 @@ import logging import selectors import socket +import ssl from postgresql_proxy import connection, config_schema as cfg from postgresql_proxy.interceptors import ResponseInterceptor, CommandInterceptor @@ -123,7 +124,7 @@ def accept_wrapper(self, sock: socket.socket): # Check if SSL is enabled for this proxy if self.ssl_context: # Handle SSL negotiation - must happen before setblocking(False) - clientsocket = _handle_ssl_negotiation(clientsocket, self.ssl_context) + clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context) clientsocket.setblocking(False) self.num_clients += 1 @@ -132,13 +133,13 @@ def accept_wrapper(self, sock: socket.socket): "Connection from %s, connection initiated %s (SSL: %s)", address, sock_name, - ssl_context is not None, + self.ssl_context is not None, ) events = selectors.EVENT_READ context = {"instance_config": self.instance_config} - conn = pg_connection.Connection( + conn = connection.Connection( clientsocket, name=sock_name, address=address, @@ -175,7 +176,7 @@ def accept_wrapper(self, sock: socket.socket): LOG.warning("Error accepting connection in Postgres proxy: %s", e) def _handle_ssl_negotiation( - client_socket: socket.socket, ssl_context: ssl.SSLContext + self, client_socket: socket.socket, ssl_context: ssl.SSLContext ) -> socket.socket: """ Handle PostgreSQL SSL negotiation on an accepted socket. From 6a997f3aae38cfc8f19ef22fd3b7f57e1640c2c0 Mon Sep 17 00:00:00 2001 From: Cristopher Pinzon Date: Thu, 26 Mar 2026 13:12:06 -0500 Subject: [PATCH 3/3] remove trycatch and add typing --- postgresql_proxy/proxy.py | 158 +++++++++++++++++++------------------- 1 file changed, 79 insertions(+), 79 deletions(-) diff --git a/postgresql_proxy/proxy.py b/postgresql_proxy/proxy.py index fad632d..ef4ad19 100644 --- a/postgresql_proxy/proxy.py +++ b/postgresql_proxy/proxy.py @@ -24,10 +24,14 @@ interceptors.py - intercepting for modification ''' +from __future__ import annotations + import logging import selectors import socket import ssl +from types import ModuleType + from postgresql_proxy import connection, config_schema as cfg from postgresql_proxy.interceptors import ResponseInterceptor, CommandInterceptor @@ -42,7 +46,13 @@ class SelectorKeyProxy(selectors.SelectorKey): class Proxy(object): - def __init__(self, instance_config, plugins, debug=False, ssl_context=None): + def __init__( + self, + instance_config: cfg.InstanceSettings, + plugins: dict[str, ModuleType], + debug: bool = False, + ssl_context: ssl.SSLContext | None = None, + ) -> None: self.plugins = plugins self.num_clients = 0 self.instance_config = instance_config @@ -117,63 +127,58 @@ def accept_wrapper(self, sock: socket.socket): :param sock: the client socket :return: """ - try: - # Accept the raw connection - clientsocket, address = sock.accept() - - # Check if SSL is enabled for this proxy - if self.ssl_context: - # Handle SSL negotiation - must happen before setblocking(False) - clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context) - - clientsocket.setblocking(False) - self.num_clients += 1 - sock_name = f"{self.instance_config.listen.name}_{self.num_clients}" - LOG.info( - "Connection from %s, connection initiated %s (SSL: %s)", - address, - sock_name, - self.ssl_context is not None, - ) - events = selectors.EVENT_READ - context = {"instance_config": self.instance_config} + # Accept the raw connection + clientsocket, address = sock.accept() + + # Check if SSL is enabled for this proxy + if self.ssl_context: + # Handle SSL negotiation - must happen before setblocking(False) + clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context) + + clientsocket.setblocking(False) + self.num_clients += 1 + sock_name = f"{self.instance_config.listen.name}_{self.num_clients}" + LOG.info( + "Connection from %s, connection initiated %s (SSL: %s)", + address, + sock_name, + self.ssl_context is not None, + ) + + events = selectors.EVENT_READ + context = {"instance_config": self.instance_config} + + conn = connection.Connection( + clientsocket, + name=sock_name, + address=address, + events=events, + context=context, + ) + + pg_conn = self._create_pg_connection(address, context) - conn = connection.Connection( - clientsocket, - name=sock_name, - address=address, - events=events, - context=context, + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.responses is not None + ): + pg_conn.interceptor = ResponseInterceptor( + self.instance_config.intercept.responses, self.plugins, context ) + pg_conn.redirect_conn = conn + + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.commands is not None + ): + conn.interceptor = CommandInterceptor( + self.instance_config.intercept.commands, self.plugins, context + ) + conn.redirect_conn = pg_conn - pg_conn = self._create_pg_connection(address, context) - - if ( - self.instance_config.intercept is not None - and self.instance_config.intercept.responses is not None - ): - pg_conn.interceptor = ResponseInterceptor( - self.instance_config.intercept.responses, self.plugins, context - ) - pg_conn.redirect_conn = conn - - if ( - self.instance_config.intercept is not None - and self.instance_config.intercept.commands is not None - ): - conn.interceptor = CommandInterceptor( - self.instance_config.intercept.commands, self.plugins, context - ) - conn.redirect_conn = pg_conn - - self._register_conn(conn) - self._register_conn(pg_conn) - - except ConnectionRefusedError: - LOG.debug("Connection refused in Postgres proxy server - instance not (yet) available") - except Exception as e: - LOG.warning("Error accepting connection in Postgres proxy: %s", e) + self._register_conn(conn) + self._register_conn(pg_conn) def _handle_ssl_negotiation( self, client_socket: socket.socket, ssl_context: ssl.SSLContext @@ -189,32 +194,27 @@ def _handle_ssl_negotiation( Returns the SSL-wrapped socket if negotiation succeeds, or the original socket. """ - try: - # Peek at the first 8 bytes to check for SSLRequest - # Using MSG_PEEK so we don't consume the data if it's not SSLRequest - client_socket.setblocking(True) - data = client_socket.recv(8, socket.MSG_PEEK) - - if len(data) == 8: - length = int.from_bytes(data[:4], "big") - code = int.from_bytes(data[4:8], "big") - - if length == 8 and code == 80877103: # SSLRequest code - # Consume the SSLRequest - client_socket.recv(8) - # Send 'S' to indicate SSL is supported - client_socket.send(b"S") - # Wrap socket with SSL - ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) - LOG.debug("SSL handshake completed for PostgreSQL connection") - return ssl_socket - - # Not an SSLRequest, return original socket - return client_socket - except Exception as e: - LOG.debug("Error during PostgreSQL SSL negotiation: %s", e) - return client_socket + # Peek at the first 8 bytes to check for SSLRequest + # Using MSG_PEEK so we don't consume the data if it's not SSLRequest + data = client_socket.recv(8, socket.MSG_PEEK) + + if len(data) == 8: + length = int.from_bytes(data[:4], "big") + code = int.from_bytes(data[4:8], "big") + + if length == 8 and code == 80877103: # SSLRequest code + # Consume the SSLRequest + client_socket.recv(8) + # Send 'S' to indicate SSL is supported + client_socket.send(b"S") + # Wrap socket with SSL + ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) + LOG.debug("SSL handshake completed for PostgreSQL connection") + return ssl_socket + + # Not an SSLRequest, return original socket + return client_socket def service_connection(self, key: SelectorKeyProxy, mask): """