Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 83 additions & 17 deletions postgresql_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@
interceptors.py - intercepting for modification
'''

from __future__ import annotations

import logging
import selectors
import socket
import ssl
from types import ModuleType

from postgresql_proxy import connection, config_schema as cfg
from postgresql_proxy.interceptors import ResponseInterceptor, CommandInterceptor

Expand All @@ -41,14 +46,22 @@ class SelectorKeyProxy(selectors.SelectorKey):


class Proxy(object):
def __init__(self, instance_config, plugins, debug=False):
def __init__(
self,
instance_config: cfg.InstanceSettings,
plugins: dict[str, ModuleType],
debug: bool = False,
ssl_context: ssl.SSLContext | None = None,
) -> None:
self.plugins = plugins
self.num_clients = 0
self.instance_config = instance_config
self.connections = []
self.selector = selectors.DefaultSelector()
self.running = True
self.sock = None
self.ssl_context = ssl_context

# this is used to track leftover sockets
self._debug = debug
if self._debug:
Expand Down Expand Up @@ -114,42 +127,95 @@ def accept_wrapper(self, sock: socket.socket):
:param sock: the client socket
:return:
"""
clientsocket, address = sock.accept() # Should be ready to

# Accept the raw connection
clientsocket, address = sock.accept()

# Check if SSL is enabled for this proxy
if self.ssl_context:
# Handle SSL negotiation - must happen before setblocking(False)
clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context)

clientsocket.setblocking(False)
self.num_clients += 1
sock_name = '{}_{}'.format(self.instance_config.listen.name, self.num_clients)
LOG.info("connection from %s, connection initiated %s", address, sock_name)
events = selectors.EVENT_READ
sock_name = f"{self.instance_config.listen.name}_{self.num_clients}"
LOG.info(
"Connection from %s, connection initiated %s (SSL: %s)",
address,
sock_name,
self.ssl_context is not None,
)

# Context dictionary, for sharing state data, connection details, which might be useful for interceptors
context = {
'instance_config': self.instance_config
}
events = selectors.EVENT_READ
context = {"instance_config": self.instance_config}

# create a Connection object, representing the relation between a proxied client to postgres
conn = connection.Connection(
clientsocket,
name=sock_name,
address=address,
events=events,
context=context
context=context,
)

# create the connection to Postgres
pg_conn = self._create_pg_connection(address, context)

if self.instance_config.intercept is not None and self.instance_config.intercept.responses is not None:
pg_conn.interceptor = ResponseInterceptor(self.instance_config.intercept.responses, self.plugins, context)
if (
self.instance_config.intercept is not None
and self.instance_config.intercept.responses is not None
):
pg_conn.interceptor = ResponseInterceptor(
self.instance_config.intercept.responses, self.plugins, context
)
pg_conn.redirect_conn = conn

if self.instance_config.intercept is not None and self.instance_config.intercept.commands is not None:
conn.interceptor = CommandInterceptor(self.instance_config.intercept.commands, self.plugins, context)
if (
self.instance_config.intercept is not None
and self.instance_config.intercept.commands is not None
):
conn.interceptor = CommandInterceptor(
self.instance_config.intercept.commands, self.plugins, context
)
conn.redirect_conn = pg_conn

# Register both connections to be watched by the selector
self._register_conn(conn)
self._register_conn(pg_conn)

def _handle_ssl_negotiation(
self, client_socket: socket.socket, ssl_context: ssl.SSLContext
) -> socket.socket:
"""
Handle PostgreSQL SSL negotiation on an accepted socket.

PostgreSQL SSL flow:
1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4)
2. Server responds 'S' (SSL supported) or 'N' (not supported)
3. If 'S', TLS handshake follows
4. After TLS, normal PostgreSQL protocol begins

Returns the SSL-wrapped socket if negotiation succeeds, or the original socket.
"""

# Peek at the first 8 bytes to check for SSLRequest
# Using MSG_PEEK so we don't consume the data if it's not SSLRequest
data = client_socket.recv(8, socket.MSG_PEEK)

if len(data) == 8:
length = int.from_bytes(data[:4], "big")
code = int.from_bytes(data[4:8], "big")

if length == 8 and code == 80877103: # SSLRequest code
# Consume the SSLRequest
client_socket.recv(8)
# Send 'S' to indicate SSL is supported
client_socket.send(b"S")
# Wrap socket with SSL
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
LOG.debug("SSL handshake completed for PostgreSQL connection")
return ssl_socket

# Not an SSLRequest, return original socket
return client_socket

def service_connection(self, key: SelectorKeyProxy, mask):
"""
This method proxies the messages between socket. It will use properties of the Connection object to
Expand Down