diff --git a/postgresql_proxy/proxy.py b/postgresql_proxy/proxy.py index 1417db9..ef4ad19 100644 --- a/postgresql_proxy/proxy.py +++ b/postgresql_proxy/proxy.py @@ -24,9 +24,14 @@ interceptors.py - intercepting for modification ''' +from __future__ import annotations + import logging import selectors import socket +import ssl +from types import ModuleType + from postgresql_proxy import connection, config_schema as cfg from postgresql_proxy.interceptors import ResponseInterceptor, CommandInterceptor @@ -41,7 +46,13 @@ class SelectorKeyProxy(selectors.SelectorKey): class Proxy(object): - def __init__(self, instance_config, plugins, debug=False): + def __init__( + self, + instance_config: cfg.InstanceSettings, + plugins: dict[str, ModuleType], + debug: bool = False, + ssl_context: ssl.SSLContext | None = None, + ) -> None: self.plugins = plugins self.num_clients = 0 self.instance_config = instance_config @@ -49,6 +60,8 @@ def __init__(self, instance_config, plugins, debug=False): self.selector = selectors.DefaultSelector() self.running = True self.sock = None + self.ssl_context = ssl_context + # this is used to track leftover sockets self._debug = debug if self._debug: @@ -114,42 +127,95 @@ def accept_wrapper(self, sock: socket.socket): :param sock: the client socket :return: """ - clientsocket, address = sock.accept() # Should be ready to + + # Accept the raw connection + clientsocket, address = sock.accept() + + # Check if SSL is enabled for this proxy + if self.ssl_context: + # Handle SSL negotiation - must happen before setblocking(False) + clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context) + clientsocket.setblocking(False) self.num_clients += 1 - sock_name = '{}_{}'.format(self.instance_config.listen.name, self.num_clients) - LOG.info("connection from %s, connection initiated %s", address, sock_name) - events = selectors.EVENT_READ + sock_name = f"{self.instance_config.listen.name}_{self.num_clients}" + LOG.info( + "Connection from %s, connection initiated %s (SSL: %s)", + address, + sock_name, + self.ssl_context is not None, + ) - # Context dictionary, for sharing state data, connection details, which might be useful for interceptors - context = { - 'instance_config': self.instance_config - } + events = selectors.EVENT_READ + context = {"instance_config": self.instance_config} - # create a Connection object, representing the relation between a proxied client to postgres conn = connection.Connection( clientsocket, name=sock_name, address=address, events=events, - context=context + context=context, ) - # create the connection to Postgres pg_conn = self._create_pg_connection(address, context) - if self.instance_config.intercept is not None and self.instance_config.intercept.responses is not None: - pg_conn.interceptor = ResponseInterceptor(self.instance_config.intercept.responses, self.plugins, context) + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.responses is not None + ): + pg_conn.interceptor = ResponseInterceptor( + self.instance_config.intercept.responses, self.plugins, context + ) pg_conn.redirect_conn = conn - if self.instance_config.intercept is not None and self.instance_config.intercept.commands is not None: - conn.interceptor = CommandInterceptor(self.instance_config.intercept.commands, self.plugins, context) + if ( + self.instance_config.intercept is not None + and self.instance_config.intercept.commands is not None + ): + conn.interceptor = CommandInterceptor( + self.instance_config.intercept.commands, self.plugins, context + ) conn.redirect_conn = pg_conn - # Register both connections to be watched by the selector self._register_conn(conn) self._register_conn(pg_conn) + def _handle_ssl_negotiation( + self, client_socket: socket.socket, ssl_context: ssl.SSLContext + ) -> socket.socket: + """ + Handle PostgreSQL SSL negotiation on an accepted socket. + + PostgreSQL SSL flow: + 1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4) + 2. Server responds 'S' (SSL supported) or 'N' (not supported) + 3. If 'S', TLS handshake follows + 4. After TLS, normal PostgreSQL protocol begins + + Returns the SSL-wrapped socket if negotiation succeeds, or the original socket. + """ + + # Peek at the first 8 bytes to check for SSLRequest + # Using MSG_PEEK so we don't consume the data if it's not SSLRequest + data = client_socket.recv(8, socket.MSG_PEEK) + + if len(data) == 8: + length = int.from_bytes(data[:4], "big") + code = int.from_bytes(data[4:8], "big") + + if length == 8 and code == 80877103: # SSLRequest code + # Consume the SSLRequest + client_socket.recv(8) + # Send 'S' to indicate SSL is supported + client_socket.send(b"S") + # Wrap socket with SSL + ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) + LOG.debug("SSL handshake completed for PostgreSQL connection") + return ssl_socket + + # Not an SSLRequest, return original socket + return client_socket + def service_connection(self, key: SelectorKeyProxy, mask): """ This method proxies the messages between socket. It will use properties of the Connection object to