From 767bd166000aad4e56d97e38611eac6b63f59b1d Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Wed, 15 Apr 2026 16:59:53 -0400 Subject: [PATCH 1/2] feat: support standalone install with embedded postgres --- pyproject.toml | 4 + testgen/__main__.py | 142 +++++++++++++++- testgen/commands/run_launch_db_config.py | 15 +- testgen/commands/run_quick_start.py | 15 +- testgen/common/database/database_service.py | 28 +++- .../flavor/postgresql_flavor_service.py | 9 + testgen/common/logs.py | 22 +-- testgen/common/models/__init__.py | 14 +- testgen/common/standalone_postgres.py | 141 ++++++++++++++++ testgen/settings.py | 158 ++++++++++-------- testgen/ui/app.py | 13 +- .../components/frontend/js/display_utils.js | 2 +- testgen/ui/static/js/display_utils.js | 2 +- 13 files changed, 458 insertions(+), 107 deletions(-) create mode 100644 testgen/common/standalone_postgres.py diff --git a/pyproject.toml b/pyproject.toml index 8248d498..bb6e7b49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,10 @@ dependencies = [ ] [project.optional-dependencies] +standalone = [ + "pixeltable-pgserver>=0.5.1", +] + dev = [ "invoke==2.2.0", "ruff==0.4.1", diff --git a/testgen/__main__.py b/testgen/__main__.py index 29c6d3b0..0b9005a1 100644 --- a/testgen/__main__.py +++ b/testgen/__main__.py @@ -1,10 +1,16 @@ +import base64 +import importlib import logging import os +import platform +import secrets import signal import subprocess import sys from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta +from importlib.metadata import version as pkg_version +from pathlib import Path import click from click.core import Context @@ -42,6 +48,13 @@ get_tg_schema, version_service, ) +from testgen.common.standalone_postgres import ( + STANDALONE_URI_ENV_VAR, + get_home_dir as get_testgen_home, + get_server_uri, + is_standalone_mode, + start_server as start_standalone_postgres, +) from testgen.common.models import with_database_session from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.settings import PersistedSetting @@ -99,19 +112,23 @@ def invoke(self, ctx: Context): ) @click.pass_context def cli(ctx: Context, verbose: bool): + if is_standalone_mode(): + start_standalone_postgres() + if verbose: configure_logging(level=logging.DEBUG) else: configure_logging(level=logging.INFO) ctx.obj = Configuration(verbose=verbose) - status_ok, message = docker_service.check_basic_configuration() - if not status_ok: - click.secho(message, fg="red") - sys.exit(1) + if not is_standalone_mode() and ctx.invoked_subcommand != "standalone-setup": + status_ok, message = docker_service.check_basic_configuration() + if not status_ok: + click.secho(message, fg="red") + sys.exit(1) if ( - ctx.invoked_subcommand not in ["run-app", "ui", "setup-system-db", "upgrade-system-version", "quick-start"] + ctx.invoked_subcommand not in ["run-app", "ui", "setup-system-db", "upgrade-system-version", "quick-start", "standalone-setup"] and not is_db_revision_up_to_date() ): click.secho("The system database schema is outdated. Automatically running the following command:", fg="red") @@ -472,6 +489,110 @@ def quick_start( click.echo("Quick start has successfully finished.") +@cli.command("standalone-setup", help="Set up TestGen for standalone use with embedded PostgreSQL (no Docker required).") +@click.option("--username", prompt="Admin username", default="admin", help="Username for the TestGen web UI.") +@click.option( + "--password", prompt="Admin password", hide_input=True, confirmation_prompt=True, + default="testgen", help="Password for the TestGen web UI.", +) +def setup_standalone(username: str, password: str): + config_dir = get_testgen_home() + config_path = config_dir / "config.env" + + if config_path.exists(): + if not click.confirm(f"Config already exists at {config_path}. Overwrite?"): + click.echo("Aborted.") + return + + # Generate secrets (same approach as dk-installer) + def generate_secret(length: int = 12) -> str: + alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + return "".join(secrets.choice(alphabet) for _ in range(length)) + + jwt_key = base64.b64encode(secrets.token_bytes(32)).decode() + decrypt_salt = generate_secret() + decrypt_password = generate_secret() + log_dir = str(config_dir / "log") + + config_dir.mkdir(parents=True, exist_ok=True) + + config_lines = [ + "# TestGen standalone configuration", + "# Generated by: testgen standalone-setup", + "", + "# Standalone mode (embedded PostgreSQL)", + "TG_STANDALONE_MODE=yes", + "", + "# UI credentials", + f"TESTGEN_USERNAME={username}", + f"TESTGEN_PASSWORD={password}", + "", + "# Encryption keys", + f"TG_DECRYPT_SALT={decrypt_salt}", + f"TG_DECRYPT_PASSWORD={decrypt_password}", + f"TG_JWT_HASHING_KEY={jwt_key}", + "", + "# Logging", + f"TESTGEN_LOG_FILE_PATH={log_dir}", + "", + "# Analytics", + "TG_ANALYTICS=yes", + "", + "# Trust target database certificates (for SQL Server, etc.)", + "TG_TARGET_DB_TRUST_SERVER_CERTIFICATE=yes", + "TG_EXPORT_TO_OBSERVABILITY_VERIFY_SSL=no", + ] + config_path.write_text("\n".join(config_lines) + "\n") + click.echo(f"Config written to {config_path}") + + # Reload settings — the module was already evaluated at import time + # before the config file existed. Reloading re-reads the new file + # and re-evaluates all module-level variables. + importlib.reload(settings) + + # Patch Streamlit to support editable-install component resolution + click.echo("Patching Streamlit...") + from testgen.ui.scripts.patch_streamlit import patch as patch_streamlit + patch_streamlit(dev=True) + + # Start embedded PostgreSQL (standalone mode is now active via config) + start_standalone_postgres() + + # Initialize the database + click.echo("Initializing database...") + run_launch_db_config(delete_db=False) + + # Send analytics event for pip install tracking + try: + from testgen.common.mixpanel_service import MixpanelService + + mp = MixpanelService() + mp.send_event( + "standalone_setup", + username=username, + install_type="standalone", + version=pkg_version("dataops-testgen"), + python_info=f"{platform.python_implementation()} {platform.python_version()}", + **{"$os": platform.system()}, + os_version=platform.release(), + os_arch=platform.machine(), + ) + except Exception: # noqa: S110 + pass + + click.echo("") + click.echo(click.style("TestGen is ready!", fg="green", bold=True)) + click.echo("") + click.echo(" To load demo data (optional):") + click.echo(" testgen quick-start") + click.echo("") + click.echo(" Start the application:") + click.echo(" testgen run-app") + click.echo("") + click.echo(" Then open http://localhost:8501 in your browser.") + click.echo(f" Log in with username: {username}") + + @cli.command("setup-system-db", help="Use to initialize the TestGen system database.") @click.option( "--delete-db", @@ -728,6 +849,15 @@ def init_ui(): init_ui() app_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui/app.py") + + # In standalone mode, pass the pgserver URI to the Streamlit subprocess + # so it can connect without acquiring the pgserver file lock. + child_env = {**os.environ, "TG_JOB_SOURCE": "UI"} + if is_standalone_mode(): + server_uri = get_server_uri() + if server_uri: + child_env = {**os.environ, "TG_JOB_SOURCE": "UI", STANDALONE_URI_ENV_VAR: server_uri} + process= subprocess.Popen( [ # noqa: S607 "streamlit", @@ -742,7 +872,7 @@ def init_ui(): "--", f"{'--debug' if settings.IS_DEBUG else ''}", ], - env={**os.environ, "TG_JOB_SOURCE": "UI"} + env=child_env, ) def term_ui(signum, _): LOG.info(f"Sending termination signal {signum} to Testgen UI") diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py index 0d926fbe..41115afd 100644 --- a/testgen/commands/run_launch_db_config.py +++ b/testgen/commands/run_launch_db_config.py @@ -4,6 +4,7 @@ from testgen import settings from testgen.common import create_database, execute_db_queries from testgen.common.credentials import get_tg_db, get_tg_schema +from testgen.common.standalone_postgres import get_home_dir, is_standalone_mode from testgen.common.database.database_service import get_queries_for_command from testgen.common.encrypt import EncryptText, encrypt_ui_password from testgen.common.models import with_database_session @@ -22,6 +23,14 @@ def _get_latest_revision_number(): def _get_params_mapping() -> dict: ui_user_encrypted_password = encrypt_ui_password(settings.PASSWORD) + project_host = settings.PROJECT_DATABASE_HOST + project_user = settings.PROJECT_DATABASE_USER + project_password = settings.PROJECT_DATABASE_PASSWORD + if is_standalone_mode(): + project_host = str(get_home_dir() / "pgdata") + project_user = "postgres" + project_password = "" + return { "UI_USER_NAME": settings.USERNAME, "UI_USER_USERNAME": settings.USERNAME, @@ -33,10 +42,10 @@ def _get_params_mapping() -> dict: "SQL_FLAVOR": settings.PROJECT_SQL_FLAVOR, "PROJECT_NAME": settings.PROJECT_NAME, "PROJECT_DB": settings.PROJECT_DATABASE_NAME, - "PROJECT_USER": settings.PROJECT_DATABASE_USER, + "PROJECT_USER": project_user, "PROJECT_PORT": settings.PROJECT_DATABASE_PORT, - "PROJECT_HOST": settings.PROJECT_DATABASE_HOST, - "PROJECT_PW_ENCRYPTED": EncryptText(settings.PROJECT_DATABASE_PASSWORD), + "PROJECT_HOST": project_host, + "PROJECT_PW_ENCRYPTED": EncryptText(project_password), "PROJECT_HTTP_PATH": "", "PROJECT_SERVICE_ACCOUNT_KEY": "", "PROJECT_SCHEMA": settings.PROJECT_DATABASE_SCHEMA, diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index f1885c69..adb9a36f 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -8,6 +8,7 @@ from testgen import settings from testgen.commands.run_launch_db_config import get_app_db_params_mapping, run_launch_db_config +from testgen.common.standalone_postgres import get_home_dir, is_standalone_mode from testgen.commands.test_generation import run_monitor_generation from testgen.common.credentials import get_tg_schema from testgen.common.database.database_service import ( @@ -93,14 +94,22 @@ def _prepare_connection_to_target_database(params_mapping): def _get_settings_params_mapping() -> dict: + host = settings.PROJECT_DATABASE_HOST + admin_user = settings.DATABASE_ADMIN_USER + admin_password = settings.DATABASE_ADMIN_PASSWORD + if is_standalone_mode(): + host = str(get_home_dir() / "pgdata") + admin_user = "postgres" + admin_password = "" + return { - "TESTGEN_ADMIN_USER": settings.DATABASE_ADMIN_USER, - "TESTGEN_ADMIN_PASSWORD": settings.DATABASE_ADMIN_PASSWORD, + "TESTGEN_ADMIN_USER": admin_user, + "TESTGEN_ADMIN_PASSWORD": admin_password, "SCHEMA_NAME": get_tg_schema(), "PROJECT_DB": settings.PROJECT_DATABASE_NAME, "PROJECT_SCHEMA": settings.PROJECT_DATABASE_SCHEMA, "PROJECT_KEY": settings.PROJECT_KEY, - "PROJECT_DB_HOST": settings.PROJECT_DATABASE_HOST, + "PROJECT_DB_HOST": host, "PROJECT_DB_PORT": settings.PROJECT_DATABASE_PORT, "SQL_FLAVOR": settings.PROJECT_SQL_FLAVOR, } diff --git a/testgen/common/database/database_service.py b/testgen/common/database/database_service.py index 0e338318..dae77d6d 100644 --- a/testgen/common/database/database_service.py +++ b/testgen/common/database/database_service.py @@ -32,6 +32,7 @@ SQLFlavor, resolve_connection_params, ) +from testgen.common.standalone_postgres import get_connection_string as get_standalone_connection_string, is_standalone_mode from testgen.common.read_file import get_template_files from testgen.utils import get_exception_message @@ -370,16 +371,27 @@ def _init_app_db_connection( engine = engine_cache.app_db if not engine: - user = user_override if is_admin else get_tg_username() - password = password_override if (is_admin or password_override is not None) else get_tg_password() + if is_standalone_mode(): + connection_string = get_standalone_connection_string(database_name) + else: + user = user_override if is_admin else get_tg_username() + password = password_override if (is_admin or password_override is not None) else get_tg_password() - # STANDARD FORMAT: flavor://username:password@host:port/database - connection_string = ( - f"postgresql://{user}:{quote_plus(password)}@{get_tg_host()}:{get_tg_port()}/{database_name}" - ) + # STANDARD FORMAT: flavor://username:password@host:port/database + connection_string = ( + f"postgresql://{user}:{quote_plus(password)}@{get_tg_host()}:{get_tg_port()}/{database_name}" + ) try: - engine: Engine = create_engine(connection_string, connect_args={"connect_timeout": 3600}) - engine_cache.app_db = engine + engine: Engine = create_engine( + connection_string, + connect_args={ + "connect_timeout": 3600, + # Force UTC so TIMESTAMP-without-tz inserts aren't silently shifted. + "options": "-c TimeZone=UTC", + }, + ) + if user_type == "normal": + engine_cache.app_db = engine except SQLAlchemyError as e: raise ValueError(f"Failed to create engine for App database '{database_name}' (User type = {user_type})") from e diff --git a/testgen/common/database/flavor/postgresql_flavor_service.py b/testgen/common/database/flavor/postgresql_flavor_service.py index 65c10dd4..99f968c6 100644 --- a/testgen/common/database/flavor/postgresql_flavor_service.py +++ b/testgen/common/database/flavor/postgresql_flavor_service.py @@ -1,6 +1,15 @@ +from urllib.parse import quote_plus + +from testgen.common.database.flavor.flavor_service import ResolvedConnectionParams from testgen.common.database.flavor.redshift_flavor_service import RedshiftFlavorService class PostgresqlFlavorService(RedshiftFlavorService): escaped_underscore = "\\_" + + def get_connection_string_from_fields(self, params: ResolvedConnectionParams) -> str: + if params.host.startswith("/"): + # Unix socket path — use query-param format for psycopg2 + return f"{self.url_scheme}://{params.username}:{quote_plus(params.password)}@/{params.dbname}?host={params.host}" + return super().get_connection_string_from_fields(params) diff --git a/testgen/common/logs.py b/testgen/common/logs.py index 6c42b353..566de050 100644 --- a/testgen/common/logs.py +++ b/testgen/common/logs.py @@ -29,16 +29,18 @@ def configure_logging( logger.addHandler(console_handler) if settings.LOG_TO_FILE: - os.makedirs(settings.LOG_FILE_PATH, exist_ok=True) - - file_handler = ConcurrentTimedRotatingFileHandler( - get_log_full_path(), - when="MIDNIGHT", - interval=1, - backupCount=int(settings.LOG_FILE_MAX_QTY), - ) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + try: + os.makedirs(settings.LOG_FILE_PATH, exist_ok=True) + file_handler = ConcurrentTimedRotatingFileHandler( + get_log_full_path(), + when="MIDNIGHT", + interval=1, + backupCount=int(settings.LOG_FILE_MAX_QTY), + ) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + except OSError: + logger.warning("Cannot write logs to %s — file logging disabled", settings.LOG_FILE_PATH) def get_log_full_path() -> str: diff --git a/testgen/common/models/__init__.py b/testgen/common/models/__init__.py index 6e2b581c..898090dd 100644 --- a/testgen/common/models/__init__.py +++ b/testgen/common/models/__init__.py @@ -5,9 +5,7 @@ import urllib.parse from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session as SQLAlchemySession -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session as SQLAlchemySession, sessionmaker from testgen import settings @@ -19,11 +17,17 @@ echo=False, connect_args={ "application_name": platform.node(), - "options": f"-csearch_path={settings.DATABASE_SCHEMA}", + # TimeZone=UTC so TIMESTAMP (no-tz) columns store aware UTC datetimes as-is. + # Without this, pgserver inherits the OS TZ and silently shifts + # timestamps on insert, which make_json_safe then re-reads as UTC. + "options": f"-csearch_path={settings.DATABASE_SCHEMA} -c TimeZone=UTC", }, ) -Base = declarative_base() +class Base(DeclarativeBase): + # Allow legacy Column() + type-hint patterns without Mapped[]. + # Can be removed once all models use Mapped[] annotations. + __allow_unmapped__ = True Session = sessionmaker( engine, diff --git a/testgen/common/standalone_postgres.py b/testgen/common/standalone_postgres.py new file mode 100644 index 00000000..ecfcb8ce --- /dev/null +++ b/testgen/common/standalone_postgres.py @@ -0,0 +1,141 @@ +"""Embedded PostgreSQL server for standalone (pip-only) installations. + +When TestGen is installed with `pip install testgen[standalone]`, this module +manages an embedded PostgreSQL instance via `pixeltable-pgserver`. The server +stores its data under a configurable directory and runs as the current OS +user — no Docker, no system Postgres, no root access required. +""" + +import atexit +import logging +import os +import platform +from pathlib import Path +from urllib.parse import urlparse, urlunparse + +from testgen import settings + +LOG = logging.getLogger("testgen") + +_server = None + +STANDALONE_MODE_ENV_VAR = "TG_STANDALONE_MODE" +HOME_DIR_ENV_VAR = "TG_TESTGEN_HOME" +STANDALONE_URI_ENV_VAR = "_TG_STANDALONE_URI" + + +def get_home_dir() -> Path: + env_dir = os.getenv(HOME_DIR_ENV_VAR) + return Path(env_dir) if env_dir else Path.home() / ".testgen" + + +def is_standalone_mode() -> bool: + return settings.getenv(STANDALONE_MODE_ENV_VAR, "no").lower() in ("yes", "true", "1") + + +def start_server(data_dir: Path | None = None) -> None: + """Start the embedded PostgreSQL server. + + The server persists data across restarts in *data_dir* (default + ``$TG_TESTGEN_HOME/pgdata`` or ``~/.testgen/pgdata``). + + Calling this multiple times is safe — the second call is a no-op + if the server is already running. + """ + global _server + + if _server is not None: + return + + try: + import pixeltable_pgserver as pgserver + except ImportError: + raise RuntimeError( + "Standalone mode requires the 'standalone' extra. " + "Install with: pip install testgen[standalone]" + ) from None + + if data_dir is None: + data_dir = get_home_dir() / "pgdata" + data_dir.mkdir(parents=True, exist_ok=True) + + LOG.info("Starting embedded PostgreSQL (data: %s) ...", data_dir) + _server = pgserver.get_server(data_dir) + LOG.info("Embedded PostgreSQL ready: %s", _server.get_uri()) + + _reinitialize_orm_engine() + atexit.register(stop_server) + + +def get_server_uri() -> str | None: + """Return the pgserver URI if the server is running in this process, else ``None``.""" + return _server.get_uri() if _server is not None else None + + +def ensure_standalone_setup(server_uri: str) -> None: + """Reinitialize the ORM engine to connect to an already-running embedded instance. + + Called by child processes (e.g. Streamlit) that receive the URI from + their parent — they should NOT start pgserver themselves. + """ + if _server is not None: + return + _reinitialize_orm_engine(server_uri) + + +def _reinitialize_orm_engine(base_uri: str | None = None) -> None: + """Recreate the ORM engine to use the embedded Unix socket URI. + + ``models/__init__`` creates its engine at import time from + ``settings.DATABASE_*`` (TCP). After the embedded server starts we + must replace that engine so the ORM connects via Unix socket. + """ + from sqlalchemy import create_engine + from testgen.common import models + + uri = _build_connection_string(settings.DATABASE_NAME, base_uri) + models.engine.dispose() + models.engine = create_engine( + url=uri, + echo=False, + connect_args={ + "application_name": platform.node(), + # Keep in sync with models/__init__.py — UTC avoids silent tz shifts on TIMESTAMP inserts. + "options": f"-csearch_path={settings.DATABASE_SCHEMA} -c TimeZone=UTC", + }, + ) + models.Session.configure(bind=models.engine) + + +def stop_server() -> None: + """Stop the embedded PostgreSQL server if running.""" + global _server + if _server is not None: + LOG.info("Stopping embedded PostgreSQL ...") + _server.cleanup() + _server = None + + +def get_connection_string(database_name: str) -> str: + """Return a SQLAlchemy connection string for the given database on the embedded server.""" + return _build_connection_string(database_name) + + +def _build_connection_string(database_name: str, base_uri: str | None = None) -> str: + """Build a Unix socket connection string, replacing the database in the path. + + Resolution order for the base URI: + 1. Caller-provided ``base_uri``. + 2. ``_server.get_uri()`` when pgserver is running in this process (parent CLI). + 3. ``STANDALONE_URI_ENV_VAR`` env var — set by the parent for child processes + (Streamlit UI, scheduler) that share the already-running instance. + """ + if base_uri is None: + if _server is not None: + base_uri = _server.get_uri() + else: + base_uri = os.environ.get(STANDALONE_URI_ENV_VAR) + if not base_uri: + raise RuntimeError("Embedded PostgreSQL server is not running") + parsed = urlparse(base_uri) + return urlunparse(parsed._replace(path=f"/{database_name}")) diff --git a/testgen/settings.py b/testgen/settings.py index 8d2b4512..351bd4db 100644 --- a/testgen/settings.py +++ b/testgen/settings.py @@ -1,13 +1,39 @@ import os import typing +from pathlib import Path -IS_DEBUG_LOG_LEVEL: bool = os.getenv("TESTGEN_DEBUG_LOG_LEVEL", "no").lower() in ("yes", "true") + +def _load_config() -> dict[str, str]: + """Load ``$TG_TESTGEN_HOME/config.env`` (default ``~/.testgen/config.env``).""" + home = Path(os.environ["TG_TESTGEN_HOME"]) if "TG_TESTGEN_HOME" in os.environ else Path.home() / ".testgen" + config_path = home / "config.env" + config: dict[str, str] = {} + if config_path.is_file(): + for line in config_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + key, _, value = line.partition("=") + if key: + config[key] = value + return config + + +_config = _load_config() + + +def getenv(key: str, default: str | None = None) -> str | None: + """Look up *key* in environment first, then config file, then *default*.""" + return os.environ.get(key) or _config.get(key) or default + + +IS_DEBUG_LOG_LEVEL: bool = getenv("TESTGEN_DEBUG_LOG_LEVEL", "no").lower() in ("yes", "true") """ When set, logs will be at debug level. defaults to: `no` """ -IS_DEBUG: bool = os.getenv("TESTGEN_DEBUG", "no").lower() in ("yes", "true") +IS_DEBUG: bool = getenv("TESTGEN_DEBUG", "no").lower() in ("yes", "true") """ When True invalidates the cache with the bootstrapped application causing the changes to the routing and plugins to take effect on every @@ -17,24 +43,24 @@ defaults to: `True` """ -LOG_TO_FILE: bool = os.getenv("TESTGEN_LOG_TO_FILE", "yes").lower() in ("yes", "true") +LOG_TO_FILE: bool = getenv("TESTGEN_LOG_TO_FILE", "yes").lower() in ("yes", "true") """ When set, rotating file logs will be generated. defaults to: `True` """ -LOG_FILE_PATH: str = os.getenv("TESTGEN_LOG_FILE_PATH", "/var/lib/testgen/log") +LOG_FILE_PATH: str = getenv("TESTGEN_LOG_FILE_PATH", "/var/lib/testgen/log") """ When set, rotating file logs will be generated under this path. """ -LOG_FILE_MAX_QTY: str = os.getenv("TESTGEN_LOG_FILE_MAX_QTY", "90") +LOG_FILE_MAX_QTY: str = getenv("TESTGEN_LOG_FILE_MAX_QTY", "90") """ Maximum log files to keep, defaults to 90 days (one file per day). """ -APP_ENCRYPTION_SALT: str = os.getenv("TG_DECRYPT_SALT") +APP_ENCRYPTION_SALT: str = getenv("TG_DECRYPT_SALT") """ Salt used to encrypt and decrypt user secrets. Only allows ascii characters. @@ -44,7 +70,7 @@ from env variable: `TG_DECRYPT_SALT` """ -APP_ENCRYPTION_SECRET: str = os.getenv("TG_DECRYPT_PASSWORD") +APP_ENCRYPTION_SECRET: str = getenv("TG_DECRYPT_PASSWORD") """ Secret passcode used in combination with `APP_ENCRYPTION_SALT` to encrypt and decrypt user secrets. Only allows ascii characters. @@ -52,21 +78,21 @@ from env variable: `TG_DECRYPT_PASSWORD` """ -USERNAME: str = os.getenv("TESTGEN_USERNAME") +USERNAME: str = getenv("TESTGEN_USERNAME") """ Username to log into the web application from env variable: `TESTGEN_USERNAME` """ -PASSWORD: str = os.getenv("TESTGEN_PASSWORD") +PASSWORD: str = getenv("TESTGEN_PASSWORD") """ Password to log into the web application from env variable: `TESTGEN_PASSWORD` """ -DATABASE_USER: str = os.getenv("TG_METADATA_DB_USER", USERNAME) +DATABASE_USER: str = getenv("TG_METADATA_DB_USER", USERNAME) """ User to connect to the testgen application postgres database. @@ -74,7 +100,7 @@ defaults to: `environ[USERNAME]` """ -DATABASE_PASSWORD: str = os.getenv("TG_METADATA_DB_PASSWORD", PASSWORD) +DATABASE_PASSWORD: str = getenv("TG_METADATA_DB_PASSWORD", PASSWORD) """ Password to connect to the testgen application postgres database. @@ -82,7 +108,7 @@ defaults to: `environ[PASSWORD]` """ -DATABASE_ADMIN_USER: str = os.getenv("DATABASE_ADMIN_USER", DATABASE_USER) +DATABASE_ADMIN_USER: str = getenv("DATABASE_ADMIN_USER", DATABASE_USER) """ User with admin privileges in the testgen application postgres database used to create roles, users, database and schema. Required if the user @@ -92,7 +118,7 @@ defaults to: `environ[DATABASE_USER]` """ -DATABASE_ADMIN_PASSWORD: str = os.getenv("DATABASE_ADMIN_PASSWORD", DATABASE_PASSWORD) +DATABASE_ADMIN_PASSWORD: str = getenv("DATABASE_ADMIN_PASSWORD", DATABASE_PASSWORD) """ Password for the admin user to connect to the testgen application postgres database. @@ -101,7 +127,7 @@ defaults to: `environ[DATABASE_PASSWORD]` """ -DATABASE_EXECUTE_USER: str = os.getenv("DATABASE_EXECUTE_USER", "testgen_execute") +DATABASE_EXECUTE_USER: str = getenv("DATABASE_EXECUTE_USER", "testgen_execute") """ User to be created into the testgen application postgres database. Will be granted: @@ -113,7 +139,7 @@ defaults to: `testgen_execute` """ -DATABASE_REPORT_USER: str = os.getenv("DATABASE_REPORT_USER", "testgen_report") +DATABASE_REPORT_USER: str = getenv("DATABASE_REPORT_USER", "testgen_report") """ User to be created into the testgen application postgres database. Will be granted read_only access to all tables. @@ -122,7 +148,7 @@ defaults to: `testgen_report` """ -DATABASE_HOST: str = os.getenv("TG_METADATA_DB_HOST", "localhost") +DATABASE_HOST: str = getenv("TG_METADATA_DB_HOST", "localhost") """ Hostname where the testgen application postgres database is running in. @@ -130,7 +156,7 @@ defaults to: `localhost` """ -DATABASE_PORT: str = os.getenv("TG_METADATA_DB_PORT", "5432") +DATABASE_PORT: str = getenv("TG_METADATA_DB_PORT", "5432") """ Port at which the testgen application postgres database is exposed by the host. @@ -139,7 +165,7 @@ defaults to: `5432` """ -DATABASE_NAME: str = os.getenv("TG_METADATA_DB_NAME", "datakitchen") +DATABASE_NAME: str = getenv("TG_METADATA_DB_NAME", "datakitchen") """ Name of the database in postgres on which to store testgen metadata. @@ -147,7 +173,7 @@ defaults to: `datakitchen` """ -DATABASE_SCHEMA: str = os.getenv("TG_METADATA_DB_SCHEMA", "testgen") +DATABASE_SCHEMA: str = getenv("TG_METADATA_DB_SCHEMA", "testgen") """ Name of the schema inside the postgres database on which to store testgen metadata. @@ -156,7 +182,7 @@ defaults to: `testgen` """ -PROJECT_KEY: str = os.getenv("PROJECT_KEY", "DEFAULT") +PROJECT_KEY: str = getenv("PROJECT_KEY", "DEFAULT") """ Code used to uniquely identify the auto generated project. @@ -164,7 +190,7 @@ defaults to: `DEFAULT` """ -PROJECT_NAME: str = os.getenv("PROJECT_NAME", "Demo") +PROJECT_NAME: str = getenv("PROJECT_NAME", "Demo") """ Name to assign to the auto generated project. @@ -172,7 +198,7 @@ defaults to: `Demo` """ -PROJECT_SQL_FLAVOR: str = os.getenv("PROJECT_SQL_FLAVOR", "postgresql") +PROJECT_SQL_FLAVOR: str = getenv("PROJECT_SQL_FLAVOR", "postgresql") """ SQL flavor of the database the auto generated project will run tests against. @@ -187,7 +213,7 @@ defaults to: `postgresql` """ -PROJECT_CONNECTION_NAME: str = os.getenv("PROJECT_CONNECTION_NAME", "default") +PROJECT_CONNECTION_NAME: str = getenv("PROJECT_CONNECTION_NAME", "default") """ Name assigned to identify the connection to the project database. @@ -195,7 +221,7 @@ defaults to: `default` """ -PROJECT_CONNECTION_MAX_THREADS: int = int(os.getenv("PROJECT_CONNECTION_MAX_THREADS", "4")) +PROJECT_CONNECTION_MAX_THREADS: int = int(getenv("PROJECT_CONNECTION_MAX_THREADS", "4")) """ Maximum number of concurrent queries executed when fetching data from the project database. @@ -204,7 +230,7 @@ defaults to: `4` """ -PROJECT_CONNECTION_MAX_QUERY_CHAR: int = int(os.getenv("PROJECT_CONNECTION_MAX_QUERY_CHAR", "5000")) +PROJECT_CONNECTION_MAX_QUERY_CHAR: int = int(getenv("PROJECT_CONNECTION_MAX_QUERY_CHAR", "5000")) """ Determine how many tests are grouped together in a single query. Increase for better performance or decrease to better isolate test @@ -214,7 +240,7 @@ defaults to: `5000` """ -PROJECT_DATABASE_NAME: str = os.getenv("PROJECT_DATABASE_NAME", "demo_db") +PROJECT_DATABASE_NAME: str = getenv("PROJECT_DATABASE_NAME", "demo_db") """ Name of the database the auto generated project will run test against. @@ -223,7 +249,7 @@ defaults to: `demo_db` """ -PROJECT_DATABASE_SCHEMA: str = os.getenv("PROJECT_DATABASE_SCHEMA", "demo") +PROJECT_DATABASE_SCHEMA: str = getenv("PROJECT_DATABASE_SCHEMA", "demo") """ Name of the schema inside the project database the tests will be run against. @@ -232,7 +258,7 @@ defaults to: `demo` """ -PROJECT_DATABASE_USER: str = os.getenv("PROJECT_DATABASE_USER", DATABASE_USER) +PROJECT_DATABASE_USER: str = getenv("PROJECT_DATABASE_USER", DATABASE_USER) """ User to be used by the auto generated project to connect to the database under testing. @@ -241,7 +267,7 @@ defaults to: `environ[DATABASE_USER]` """ -PROJECT_DATABASE_PASSWORD: str = os.getenv("PROJECT_DATABASE_PASSWORD", DATABASE_PASSWORD) +PROJECT_DATABASE_PASSWORD: str = getenv("PROJECT_DATABASE_PASSWORD", DATABASE_PASSWORD) """ Password to be used by the auto generated project to connect to the database under testing. @@ -250,7 +276,7 @@ defaults to: `environ[DATABASE_PASSWORD]` """ -PROJECT_DATABASE_HOST: str = os.getenv("PROJECT_DATABASE_HOST", DATABASE_HOST) +PROJECT_DATABASE_HOST: str = getenv("PROJECT_DATABASE_HOST", DATABASE_HOST) """ Hostname where the database under testing is running in. @@ -258,7 +284,7 @@ defaults to: `environ[DATABASE_HOST]` """ -PROJECT_DATABASE_PORT: str = os.getenv("PROJECT_DATABASE_PORT", DATABASE_PORT) +PROJECT_DATABASE_PORT: str = getenv("PROJECT_DATABASE_PORT", DATABASE_PORT) """ Port at which the database under testing is exposed by the host. @@ -266,7 +292,7 @@ defaults to: `environ[DATABASE_PORT]` """ -SKIP_DATABASE_CERTIFICATE_VERIFICATION: bool = os.getenv("TG_TARGET_DB_TRUST_SERVER_CERTIFICATE", "no").lower() in ("yes", "true") +SKIP_DATABASE_CERTIFICATE_VERIFICATION: bool = getenv("TG_TARGET_DB_TRUST_SERVER_CERTIFICATE", "no").lower() in ("yes", "true") """ When True for supported SQL flavors, set up the SQLAlchemy connection to trust the database server certificate. @@ -275,7 +301,7 @@ defaults to: `True` """ -DEFAULT_TABLE_GROUPS_NAME: str = os.getenv("DEFAULT_TABLE_GROUPS_NAME", "default") +DEFAULT_TABLE_GROUPS_NAME: str = getenv("DEFAULT_TABLE_GROUPS_NAME", "default") """ Name assigned to the auto generated table group. @@ -283,7 +309,7 @@ defaults to: `default` """ -DEFAULT_TEST_SUITE_KEY: str = os.getenv("DEFAULT_TEST_SUITE_NAME", "default-suite-1") +DEFAULT_TEST_SUITE_KEY: str = getenv("DEFAULT_TEST_SUITE_NAME", "default-suite-1") """ Key to be assgined to the auto generated test suite. @@ -291,7 +317,7 @@ defaults to: `default-suite-1` """ -DEFAULT_TEST_SUITE_DESCRIPTION: str = os.getenv("DEFAULT_TEST_SUITE_DESCRIPTION", "default_suite_desc") +DEFAULT_TEST_SUITE_DESCRIPTION: str = getenv("DEFAULT_TEST_SUITE_DESCRIPTION", "default_suite_desc") """ Description for the auto generated test suite. @@ -299,7 +325,7 @@ defaults to: `default_suite_desc` """ -DEFAULT_PROFILING_TABLE_SET = os.getenv("DEFAULT_PROFILING_TABLE_SET", "") +DEFAULT_PROFILING_TABLE_SET = getenv("DEFAULT_PROFILING_TABLE_SET", "") """ Comma separated list of specific table names to include when running profiling for the project database. @@ -307,7 +333,7 @@ from env variable: `DEFAULT_PROFILING_TABLE_SET` """ -DEFAULT_PROFILING_INCLUDE_MASK = os.getenv("DEFAULT_PROFILING_INCLUDE_MASK", "%") +DEFAULT_PROFILING_INCLUDE_MASK = getenv("DEFAULT_PROFILING_INCLUDE_MASK", "%") """ A SQL filter supported by the project database's `LIKE` operator for table names to include. @@ -316,7 +342,7 @@ defaults to: `%` """ -DEFAULT_PROFILING_EXCLUDE_MASK = os.getenv("DEFAULT_PROFILING_EXCLUDE_MASK", "tmp%") +DEFAULT_PROFILING_EXCLUDE_MASK = getenv("DEFAULT_PROFILING_EXCLUDE_MASK", "tmp%") """ A SQL filter supported by the project database's `LIKE` operator for table names to exclude. @@ -325,7 +351,7 @@ defaults to: `tmp%` """ -DEFAULT_PROFILING_ID_COLUMN_MASK = os.getenv("DEFAULT_PROFILING_ID_COLUMN_MASK", "%id") +DEFAULT_PROFILING_ID_COLUMN_MASK = getenv("DEFAULT_PROFILING_ID_COLUMN_MASK", "%id") """ A SQL filter supported by the project database's `LIKE` operator representing ID columns. @@ -334,7 +360,7 @@ defaults to: `%id` """ -DEFAULT_PROFILING_SK_COLUMN_MASK = os.getenv("DEFAULT_PROFILING_SK_COLUMN_MASK", "%sk") +DEFAULT_PROFILING_SK_COLUMN_MASK = getenv("DEFAULT_PROFILING_SK_COLUMN_MASK", "%sk") """ A SQL filter supported by the project database's `LIKE` operator representing surrogate key columns. @@ -343,7 +369,7 @@ defaults to: `%sk` """ -DEFAULT_PROFILING_USE_SAMPLING: str = os.getenv("DEFAULT_PROFILING_USE_SAMPLING", "N") +DEFAULT_PROFILING_USE_SAMPLING: str = getenv("DEFAULT_PROFILING_USE_SAMPLING", "N") """ Toggle on to base profiling on a sample of records instead of the full table. Accepts `Y` or `N` @@ -352,7 +378,7 @@ defaults to: `N` """ -OBSERVABILITY_API_URL: str = os.getenv("OBSERVABILITY_API_URL", "") +OBSERVABILITY_API_URL: str = getenv("OBSERVABILITY_API_URL", "") """ API URL of your instance of Observability where to send events to for the project. @@ -362,7 +388,7 @@ from env variable: `OBSERVABILITY_API_URL` """ -OBSERVABILITY_API_KEY: str = os.getenv("OBSERVABILITY_API_KEY", "") +OBSERVABILITY_API_KEY: str = getenv("OBSERVABILITY_API_KEY", "") """ Authentication key with permissions to send events created in your instance of Observability. @@ -372,7 +398,7 @@ from env variable: `OBSERVABILITY_API_KEY` """ -OBSERVABILITY_VERIFY_SSL: bool = os.getenv("TG_EXPORT_TO_OBSERVABILITY_VERIFY_SSL", "yes").lower() in ("yes", "true") +OBSERVABILITY_VERIFY_SSL: bool = getenv("TG_EXPORT_TO_OBSERVABILITY_VERIFY_SSL", "yes").lower() in ("yes", "true") """ When False, exporting events to your instance of Observability will skip SSL verification. @@ -381,7 +407,7 @@ defaults to: `True` """ -OBSERVABILITY_EXPORT_LIMIT: int = int(os.getenv("TG_OBSERVABILITY_EXPORT_MAX_QTY", "5000")) +OBSERVABILITY_EXPORT_LIMIT: int = int(getenv("TG_OBSERVABILITY_EXPORT_MAX_QTY", "5000")) """ When exporting to your instance of Observability, the maximum number of events that will be sent to the events API on a single export. @@ -390,7 +416,7 @@ defaults to: `5000` """ -OBSERVABILITY_DEFAULT_COMPONENT_TYPE: str = os.getenv("OBSERVABILITY_DEFAULT_COMPONENT_TYPE", "dataset") +OBSERVABILITY_DEFAULT_COMPONENT_TYPE: str = getenv("OBSERVABILITY_DEFAULT_COMPONENT_TYPE", "dataset") """ When exporting to your instance of Observability, the type of event that will be sent to the events API. @@ -399,7 +425,7 @@ defaults to: `dataset` """ -OBSERVABILITY_DEFAULT_COMPONENT_KEY: str = os.getenv("OBSERVABILITY_DEFAULT_COMPONENT_KEY", "default") +OBSERVABILITY_DEFAULT_COMPONENT_KEY: str = getenv("OBSERVABILITY_DEFAULT_COMPONENT_KEY", "default") """ When exporting to your instance of Observability, the key sent to the events API to identify the components. @@ -410,13 +436,13 @@ CHECK_FOR_LATEST_VERSION: typing.Literal["pypi", "docker"] = typing.cast( typing.Literal["pypi", "docker"], - os.getenv("TG_RELEASE_CHECK", "pypi").lower(), + getenv("TG_RELEASE_CHECK", "pypi").lower(), ) """ Specifies whether the latest version check should be based on PyPI or DockerHub. """ -DOCKER_HUB_REPOSITORY: str = os.getenv( +DOCKER_HUB_REPOSITORY: str = getenv( "TESTGEN_DOCKER_HUB_REPO", "datakitchen/dataops-testgen", ) @@ -426,18 +452,18 @@ `docker`. """ -VERSION: str = os.getenv("TESTGEN_VERSION", None) +VERSION: str = getenv("TESTGEN_VERSION", None) """ Current deployed version. The value is displayed in the UI menu. """ -SUPPORT_EMAIL: str = os.getenv("TESTGEN_SUPPORT_EMAIL", "open-source-support@datakitchen.io") +SUPPORT_EMAIL: str = getenv("TESTGEN_SUPPORT_EMAIL", "open-source-support@datakitchen.io") """ Email for contacting DataKitchen support. """ -SSL_CERT_FILE: str = os.getenv("SSL_CERT_FILE", "") -SSL_KEY_FILE: str = os.getenv("SSL_KEY_FILE", "") +SSL_CERT_FILE: str = getenv("SSL_CERT_FILE", "") +SSL_KEY_FILE: str = getenv("SSL_KEY_FILE", "") """ File paths for SSL certificate and private key to support HTTPS. Both files must be provided. @@ -451,57 +477,57 @@ Mixpanel configuration """ -INSTANCE_ID: str | None = os.getenv("TG_INSTANCE_ID", None) +INSTANCE_ID: str | None = getenv("TG_INSTANCE_ID", None) """ Random ID that uniquely identifies the instance. """ -ANALYTICS_ENABLED: bool = os.getenv("TG_ANALYTICS", "yes").lower() in ("yes", "true") +ANALYTICS_ENABLED: bool = getenv("TG_ANALYTICS", "yes").lower() in ("yes", "true") """ Disables sending usage data when set to any value except "true" and "yes". Defaults to "yes" """ -ANALYTICS_JOB_SOURCE: str = os.getenv("TG_JOB_SOURCE", "CLI") +ANALYTICS_JOB_SOURCE: str = getenv("TG_JOB_SOURCE", "CLI") """ Identifies the job trigger for analytics purposes. """ -JWT_HASHING_KEY_B64: str = os.getenv("TG_JWT_HASHING_KEY") +JWT_HASHING_KEY_B64: str = getenv("TG_JWT_HASHING_KEY") """ Random key used to sign/verify the authentication token """ -ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT: int = os.getenv("TG_ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT", 50) +ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT: int = getenv("TG_ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT", 50) """ Limit the number of records used to generate the PDF with test results and hygiene issue reports. """ -EMAIL_FROM_ADDRESS: str | None = os.getenv("TG_EMAIL_FROM_ADDRESS") +EMAIL_FROM_ADDRESS: str | None = getenv("TG_EMAIL_FROM_ADDRESS") """ Email: Sender address """ -SMTP_ENDPOINT: str | None = os.getenv("TG_SMTP_ENDPOINT") +SMTP_ENDPOINT: str | None = getenv("TG_SMTP_ENDPOINT") """ Email: SMTP endpoint """ -SMTP_PORT: int | None = int(os.getenv("TG_SMTP_PORT", 0)) or None +SMTP_PORT: int | None = int(getenv("TG_SMTP_PORT", 0)) or None """ Email: SMTP port """ -SMTP_USERNAME: str | None = os.getenv("TG_SMTP_USERNAME") +SMTP_USERNAME: str | None = getenv("TG_SMTP_USERNAME") """ Email: SMTP username """ -SMTP_PASSWORD: str | None = os.getenv("TG_SMTP_PASSWORD") +SMTP_PASSWORD: str | None = getenv("TG_SMTP_PASSWORD") """ Email: SMTP password """ -MCP_PORT: int = int(os.getenv("TG_MCP_PORT", "8510")) +MCP_PORT: int = int(getenv("TG_MCP_PORT", "8510")) """ Port for the MCP server. @@ -509,7 +535,7 @@ defaults to: `8510` """ -MCP_HOST: str = os.getenv("TG_MCP_HOST", "0.0.0.0") # noqa: S104 +MCP_HOST: str = getenv("TG_MCP_HOST", "0.0.0.0") # noqa: S104 """ Host for the MCP server. @@ -517,7 +543,7 @@ defaults to: `0.0.0.0` """ -MCP_ENABLED: bool = os.getenv("TG_MCP_ENABLED", "no").lower() in ("yes", "true") +MCP_ENABLED: bool = getenv("TG_MCP_ENABLED", "no").lower() in ("yes", "true") """ Enable the MCP server when running `testgen run-app all`. diff --git a/testgen/ui/app.py b/testgen/ui/app.py index 5ed2bc72..1e10372a 100644 --- a/testgen/ui/app.py +++ b/testgen/ui/app.py @@ -1,4 +1,5 @@ import logging +import os from urllib.parse import urlparse import streamlit as st @@ -6,6 +7,7 @@ from testgen import settings from testgen.common import version_service from testgen.common.docker_service import check_basic_configuration +from testgen.common.standalone_postgres import STANDALONE_URI_ENV_VAR, ensure_standalone_setup, is_standalone_mode from testgen.common.models import get_current_session, with_database_session from testgen.common.models.project import Project from testgen.ui import bootstrap @@ -14,6 +16,8 @@ from testgen.ui.services import javascript_service from testgen.ui.session import session +if is_standalone_mode() and (standalone_uri := os.environ.get(STANDALONE_URI_ENV_VAR)): + ensure_standalone_setup(standalone_uri) @with_database_session def render(log_level: int = logging.INFO): @@ -33,10 +37,11 @@ def render(log_level: int = logging.INFO): if not session.auth: session.auth = application.auth_class() - status_ok, message = check_basic_configuration() - if not status_ok: - st.markdown(f":red[{message}]") - return + if not is_standalone_mode(): + status_ok, message = check_basic_configuration() + if not status_ok: + st.markdown(f":red[{message}]") + return set_locale() diff --git a/testgen/ui/components/frontend/js/display_utils.js b/testgen/ui/components/frontend/js/display_utils.js index 8dc0c9f5..3d186d6f 100644 --- a/testgen/ui/components/frontend/js/display_utils.js +++ b/testgen/ui/components/frontend/js/display_utils.js @@ -38,7 +38,7 @@ function formatDuration( function formatDurationSeconds( /** @type number */ totalSeconds, ) { - if (!totalSeconds) { + if (totalSeconds == null || totalSeconds < 0) { return '--'; } diff --git a/testgen/ui/static/js/display_utils.js b/testgen/ui/static/js/display_utils.js index 8dc0c9f5..3d186d6f 100644 --- a/testgen/ui/static/js/display_utils.js +++ b/testgen/ui/static/js/display_utils.js @@ -38,7 +38,7 @@ function formatDuration( function formatDurationSeconds( /** @type number */ totalSeconds, ) { - if (!totalSeconds) { + if (totalSeconds == null || totalSeconds < 0) { return '--'; } From 01c4d59884b15c45e4b93781e70bd85880bac2ee Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 20 Apr 2026 18:45:44 -0400 Subject: [PATCH 2/2] refactor: upgrade to python 3.13 --- README.md | 2 +- deploy/install_linuxodbc.sh | 9 +-- deploy/testgen-base.dockerfile | 12 ++-- deploy/testgen.dockerfile | 10 ++- docs/local_development.md | 4 +- pyproject.toml | 25 +++---- .../queries/refresh_data_chars_query.py | 2 +- testgen/commands/run_refresh_data_chars.py | 2 +- testgen/common/database/database_service.py | 66 ++++++++++--------- .../database/flavor/mssql_flavor_service.py | 4 +- .../flavor/postgresql_flavor_service.py | 1 + .../flavor/redshift_flavor_service.py | 32 ++++++++- testgen/common/models/__init__.py | 3 +- testgen/common/models/entity.py | 22 ++++--- testgen/common/models/profiling_run.py | 10 +-- testgen/common/models/project.py | 2 +- testgen/common/models/scores.py | 27 ++++---- testgen/common/models/settings.py | 17 ++--- testgen/common/models/test_definition.py | 2 +- testgen/common/models/test_run.py | 8 +-- testgen/common/notifications/test_run.py | 2 +- testgen/common/process_service.py | 33 ++++++---- testgen/common/read_file.py | 2 +- .../frontend/js/components/score_breakdown.js | 6 +- testgen/ui/pdf/dataframe_table.py | 5 +- testgen/ui/services/database_service.py | 6 +- .../static/js/components/score_breakdown.js | 6 +- testgen/ui/views/connections.py | 10 ++- .../test_test_run_notifications.py | 8 ++- 29 files changed, 205 insertions(+), 133 deletions(-) diff --git a/README.md b/README.md index ef49135b..4b6a5726 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ source venv/bin/activate _On Windows_ ```powershell -py -3.12 -m venv venv +py -3.13 -m venv venv venv\Scripts\activate ``` diff --git a/deploy/install_linuxodbc.sh b/deploy/install_linuxodbc.sh index e0f4080d..d73eea48 100755 --- a/deploy/install_linuxodbc.sh +++ b/deploy/install_linuxodbc.sh @@ -31,19 +31,16 @@ fi openssl x509 -inform DER -in cert.crt -out /usr/local/share/ca-certificates/microsoft_tls_g2_ecc_ocsp_02.pem update-ca-certificates - # Download the desired packages + # Download the ODBC driver (msodbcsql18) only — mssql-tools18 (sqlcmd, bcp, iusql) + # is not needed at runtime and triggers false-positive secret findings in security scans curl -O https://download.microsoft.com/download/9dcab408-e0d4-4571-a81a-5a0951e3445f/msodbcsql18_18.6.1.1-1_$architecture.apk - curl -O https://download.microsoft.com/download/b60bb8b6-d398-4819-9950-2e30cf725fb0/mssql-tools18_18.6.1.1-1_$architecture.apk # Verify signature, if 'gpg' is missing install it using 'apk add gnupg': curl -O https://download.microsoft.com/download/9dcab408-e0d4-4571-a81a-5a0951e3445f/msodbcsql18_18.6.1.1-1_$architecture.sig - curl -O https://download.microsoft.com/download/b60bb8b6-d398-4819-9950-2e30cf725fb0/mssql-tools18_18.6.1.1-1_$architecture.sig curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > microsoft.gpg gpgv --keyring ./microsoft.gpg msodbcsql18_*.sig msodbcsql18_*.apk - gpgv --keyring ./microsoft.gpg mssql-tools18_*.sig mssql-tools18_*.apk - # Install the packages + # Install the ODBC driver apk add --no-cache --allow-untrusted msodbcsql18_18.6.1.1-1_$architecture.apk - apk add --no-cache --allow-untrusted mssql-tools18_18.6.1.1-1_$architecture.apk ) diff --git a/deploy/testgen-base.dockerfile b/deploy/testgen-base.dockerfile index 0a297555..fd207bcd 100644 --- a/deploy/testgen-base.dockerfile +++ b/deploy/testgen-base.dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12-alpine3.23 +FROM python:3.13-alpine3.23 ENV LANG=C.UTF-8 ENV LC_ALL=C.UTF-8 @@ -47,12 +47,12 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip==26.0 # We download the wheel for the correct arch, then extract it directly into site-packages # (wheels are zip files). gcompat provides the glibc shim needed at runtime. RUN ARCH=$(uname -m) && \ - pip download --platform manylinux2014_${ARCH} --python-version 3.12 --only-binary :all: \ + pip download --platform manylinux2014_${ARCH} --python-version 3.13 --only-binary :all: \ --no-deps -d /tmp/wheels hdbcli==2.25.31 && \ - python3 -m zipfile -e /tmp/wheels/hdbcli-*.whl /dk/lib/python3.12/site-packages/ && \ + python3 -m zipfile -e /tmp/wheels/hdbcli-*.whl /dk/lib/python3.13/site-packages/ && \ # Copy dist-info to system site-packages so pip sees hdbcli as installed during # dependency resolution (sqlalchemy-hana transitively depends on hdbcli~=2.10) - cp -r /dk/lib/python3.12/site-packages/hdbcli-*.dist-info \ + cp -r /dk/lib/python3.13/site-packages/hdbcli-*.dist-info \ "$(python3 -c 'import sysconfig; print(sysconfig.get_path("purelib"))')"/ && \ rm -rf /tmp/wheels @@ -78,4 +78,8 @@ RUN apk del \ unixodbc-dev \ apache-arrow-dev +# Remove interactive ODBC tools — not needed at runtime, and iusql triggers +# false-positive secret detection in security scanners (SECRET-3010) +RUN rm -f /usr/bin/iusql /usr/bin/isql + RUN rm -rf /root/.cache/pip /tmp/dk/install_linuxodbc.sh diff --git a/deploy/testgen.dockerfile b/deploy/testgen.dockerfile index 6708fd67..f5a58270 100644 --- a/deploy/testgen.dockerfile +++ b/deploy/testgen.dockerfile @@ -7,11 +7,15 @@ ARG TESTGEN_VERSION ARG TESTGEN_DOCKER_HUB_REPO ARG TESTGEN_SUPPORT_EMAIL -ENV PYTHONPATH=/dk/lib/python3.12/site-packages +ENV PYTHONPATH=/dk/lib/python3.13/site-packages ENV PATH=$PATH:/dk/bin RUN apk upgrade +# Remove interactive ODBC tools — not needed at runtime, and iusql triggers +# false-positive secret detection in security scanners (SECRET-3010) +RUN rm -f /usr/bin/iusql /usr/bin/isql + # Now install everything (hdbcli is pre-installed in the base image via manual wheel extraction) COPY . /tmp/dk/ RUN sed -i '/hdbcli/d' /tmp/dk/pyproject.toml /tmp/dk/testgen/pyproject.toml 2>/dev/null; \ @@ -20,7 +24,7 @@ RUN sed -i '/hdbcli/d' /tmp/dk/pyproject.toml /tmp/dk/testgen/pyproject.toml 2>/ # Generate third-party license notices from installed packages RUN pip install --no-cache-dir pip-licenses \ && SCRIPT=$(find /tmp/dk -name generate_third_party_notices.py | head -1) \ - && PYTHONPATH=/dk/lib/python3.12/site-packages python3 "$SCRIPT" --output /dk/THIRD-PARTY-NOTICES \ + && PYTHONPATH=/dk/lib/python3.13/site-packages python3 "$SCRIPT" --output /dk/THIRD-PARTY-NOTICES \ && pip uninstall -y pip-licenses RUN rm -Rf /tmp/dk /root/.cache/pip @@ -31,7 +35,7 @@ RUN addgroup -S testgen && adduser -S testgen -G testgen # Streamlit has to be able to write to these dirs RUN mkdir /var/lib/testgen -RUN chown -R testgen:testgen /var/lib/testgen /dk/lib/python3.12/site-packages/streamlit/static /dk/lib/python3.12/site-packages/testgen/ui/components/frontend +RUN chown -R testgen:testgen /var/lib/testgen /dk/lib/python3.13/site-packages/streamlit/static /dk/lib/python3.13/site-packages/testgen/ui/components/frontend ENV TESTGEN_VERSION=${TESTGEN_VERSION} ENV TESTGEN_DOCKER_HUB_REPO=${TESTGEN_DOCKER_HUB_REPO} diff --git a/docs/local_development.md b/docs/local_development.md index bbe49c78..cff533ec 100644 --- a/docs/local_development.md +++ b/docs/local_development.md @@ -27,13 +27,13 @@ From the root of your local repository, create and activate a virtual environmen _On Linux/Mac_ ```shell -python3.12 -m venv venv +python3.13 -m venv venv source venv/bin/activate ``` _On Windows_ ```powershell -py -3.12 -m venv venv +py -3.13 -m venv venv venv\Scripts\activate ``` diff --git a/pyproject.toml b/pyproject.toml index bb6e7b49..43851025 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,30 +21,31 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 5 - Production/Stable", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: System :: Monitoring", ] keywords = [ "dataops", "data", "quality", "testing", "database", "profiling" ] requires-python = ">=3.12" dependencies = [ - "PyYAML==6.0.1", - "click==8.1.3", - "sqlalchemy==1.4.46", - "databricks-sql-connector==2.9.3", + "PyYAML==6.0.3", + "click==8.3.1", + "sqlalchemy==2.0.48", + "databricks-sql-connector==4.2.5", + "databricks-sqlalchemy==2.0.9", "databricks-sdk>=0.20.0", "snowflake-sqlalchemy==1.9.0", - "sqlalchemy-bigquery==1.14.1", + "sqlalchemy-bigquery==1.16.0", "oracledb==3.4.0", "hdbcli==2.25.31", - "sqlalchemy-hana==2.1.0", - "pyodbc==5.0.0", - "psycopg2-binary==2.9.9", + "sqlalchemy-hana==4.4.0", + "pyodbc==5.2.0", + "psycopg2-binary==2.9.11", "pycryptodome==3.21", "prettytable==3.7.0", "requests_extensions==1.1.3", - "numpy==1.26.4", - "pandas==2.1.4", + "numpy==2.1.3", + "pandas==2.2.3", "streamlit==1.55.0", "streamlit-extras==0.3.0", "streamlit-aggrid==0.3.4.post3", @@ -169,7 +170,7 @@ filterwarnings = [ # for an explanation of their functionality. # WARNING: When changing mypy configurations, be sure to test them after removing your .mypy_cache [tool.mypy] -python_version = "3.12" +python_version = "3.13" check_untyped_defs = true disallow_untyped_decorators = true disallow_untyped_defs = true diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py index 9964a2d4..e5d72fa5 100644 --- a/testgen/commands/queries/refresh_data_chars_query.py +++ b/testgen/commands/queries/refresh_data_chars_query.py @@ -113,7 +113,7 @@ def get_row_counts(self, table_names: Iterable[str]) -> list[tuple[str, None]]: schema = self.table_group.table_group_schema quote = self.flavor_service.quote_character count_queries = [ - f"SELECT '{table}', COUNT(*) FROM {quote}{schema}{quote}.{quote}{table}{quote}" + f"SELECT '{table}' AS table_name, COUNT(*) AS row_count FROM {quote}{schema}{quote}.{quote}{table}{quote}" for table in table_names ] chunked_queries = chunk_queries(count_queries, " UNION ALL ", self.connection.max_query_chars) diff --git a/testgen/commands/run_refresh_data_chars.py b/testgen/commands/run_refresh_data_chars.py index a972f7f1..94f9b3e0 100644 --- a/testgen/commands/run_refresh_data_chars.py +++ b/testgen/commands/run_refresh_data_chars.py @@ -35,7 +35,7 @@ def run_data_chars_refresh(connection: Connection, table_group: TableGroup, run_ count_queries, use_target_db=True, max_threads=connection.max_threads, ) - count_map = dict(count_results) + count_map = {row["table_name"]: row["row_count"] for row in count_results} for column in data_chars: column.record_ct = count_map.get(column.table_name) diff --git a/testgen/common/database/database_service.py b/testgen/common/database/database_service.py index dae77d6d..eba7d73b 100644 --- a/testgen/common/database/database_service.py +++ b/testgen/common/database/database_service.py @@ -2,6 +2,7 @@ import csv import importlib import logging +import math import re from collections.abc import Callable, Iterable from contextlib import suppress @@ -10,11 +11,10 @@ from urllib.parse import quote_plus import psycopg2.sql -from sqlalchemy import create_engine, text -from sqlalchemy.engine import LegacyRow, RowMapping -from sqlalchemy.engine.base import Connection, Engine +from sqlalchemy import Connection, Engine, Row, create_engine, text +from sqlalchemy.engine import RowMapping from sqlalchemy.exc import ProgrammingError, SQLAlchemyError -from sqlalchemy.pool.base import _ConnectionFairy +from sqlalchemy.pool import PoolProxiedConnection from testgen import settings from testgen.common.credentials import ( @@ -32,8 +32,9 @@ SQLFlavor, resolve_connection_params, ) -from testgen.common.standalone_postgres import get_connection_string as get_standalone_connection_string, is_standalone_mode from testgen.common.read_file import get_template_files +from testgen.common.standalone_postgres import get_connection_string as get_standalone_connection_string +from testgen.common.standalone_postgres import is_standalone_mode from testgen.utils import get_exception_message LOG = logging.getLogger("testgen") @@ -103,12 +104,14 @@ def create_database( ) -> None: LOG.debug("DB operation: create_database on App database (User type = database_admin)") + # DDL like CREATE/DROP DATABASE cannot run inside a transaction. + # Use AUTOCOMMIT isolation so each statement commits immediately. connection = _init_db_connection( user_override=params["TESTGEN_ADMIN_USER"], password_override=params["TESTGEN_ADMIN_PASSWORD"], user_type="database_admin", ) - connection.execute("commit") + connection = connection.execution_options(isolation_level="AUTOCOMMIT") with connection: if drop_existing: @@ -118,20 +121,16 @@ def create_database( ), {"database_name": database_name}, ) - connection.execute("commit") - connection.execute(f"DROP DATABASE IF EXISTS {database_name}") - connection.execute("commit") + connection.execute(text(f"DROP DATABASE IF EXISTS {database_name}")) if drop_users_and_roles: if user := params.get("TESTGEN_USER"): - connection.execute(f"DROP USER IF EXISTS {user}") + connection.execute(text(f"DROP USER IF EXISTS {user}")) if report_user := params.get("TESTGEN_REPORT_USER"): - connection.execute(f"DROP USER IF EXISTS {report_user}") - connection.execute("DROP ROLE IF EXISTS testgen_execute_role") - connection.execute("DROP ROLE IF EXISTS testgen_report_role") - connection.execute("commit") + connection.execute(text(f"DROP USER IF EXISTS {report_user}")) + connection.execute(text("DROP ROLE IF EXISTS testgen_execute_role")) + connection.execute(text("DROP ROLE IF EXISTS testgen_report_role")) with suppress(ProgrammingError): - connection.execute(f"CREATE DATABASE {database_name}") - connection.close() + connection.execute(text(f"CREATE DATABASE {database_name}")) def execute_db_queries( @@ -150,7 +149,6 @@ def execute_db_queries( LOG.debug("No queries to process") for index, (query, params) in enumerate(queries): LOG.debug(f"Query {index + 1} of {len(queries)}: {query}") - transaction = connection.begin() result = connection.execute(text(query), params) row_counts.append(result.rowcount) if result.rowcount == -1: @@ -163,7 +161,7 @@ def execute_db_queries( except Exception: return_values.append(None) - transaction.commit() + connection.commit() LOG.debug(message) return return_values, row_counts @@ -180,12 +178,12 @@ def fetch_from_db_threaded( use_target_db: bool = False, max_threads: int = 4, progress_callback: Callable[[ThreadedProgress], None] | None = None, -) -> tuple[list[LegacyRow], list[str], dict[int, str]]: +) -> tuple[list[RowMapping], list[str], dict[int, str]]: LOG.debug(f"DB operation: fetch_from_db_threaded ({len(queries)}) on {'Target' if use_target_db else 'App'} database (User type = normal)") - def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[LegacyRow], list[str], int, str | None]: + def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[RowMapping], list[str], int, str | None]: LOG.debug(f"Query: {query}") - row_data: list[LegacyRow] = [] + row_data: list[RowMapping] = [] column_names: list[str] = [] error = None @@ -193,7 +191,7 @@ def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[Legacy with _init_db_connection(use_target_db) as connection: result = connection.execute(text(query), params) LOG.debug(f"{result.rowcount} records retrieved") - row_data = result.fetchall() + row_data = result.mappings().fetchall() column_names = list(result.keys()) except Exception as e: error = get_exception_message(e) @@ -201,7 +199,7 @@ def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[Legacy return row_data, column_names, index, error - result_data: list[LegacyRow] = [] + result_data: list[RowMapping] = [] result_columns: list[str] = [] error_data: dict[int, str] = {} @@ -241,7 +239,7 @@ def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[Legacy def fetch_list_from_db( query: str, params: dict | None = None, use_target_db: bool = False -) -> tuple[list[LegacyRow], list[str]]: +) -> tuple[list[Row], list[str]]: LOG.debug(f"DB operation: fetch_list_from_db on {'Target' if use_target_db else 'App'} database (User type = normal)") with _init_db_connection(use_target_db) as connection: @@ -263,11 +261,10 @@ def fetch_dict_from_db( LOG.debug(f"Query: {query}") result = connection.execute(text(query), params) LOG.debug(f"{result.rowcount} records retrieved") - # Creates list of dictionaries so records are addressible by column name - return [row._mapping for row in result] + return result.mappings().all() -def write_to_app_db(data: list[LegacyRow], column_names: Iterable[str], table_name: str) -> None: +def write_to_app_db(data: list[Row], column_names: Iterable[str], table_name: str) -> None: LOG.debug("DB operation: write_to_app_db on App database (User type = normal)") # use_raw is required to make use of the copy_expert method for fast batch ingestion @@ -275,9 +272,18 @@ def write_to_app_db(data: list[LegacyRow], column_names: Iterable[str], table_na cursor = connection.cursor() # Write List to CSV in memory + # Sanitize NaN → None: some DB connectors (e.g. Databricks via Arrow) return + # float('nan') for NULL integers. CSV would serialize these as "nan" which + # PostgreSQL rejects for numeric columns. + # RowMapping objects iterate over keys, not values — extract values explicitly. + def _row_values(row): + values = row.values() if isinstance(row, RowMapping) else row + return tuple(None if isinstance(v, float) and math.isnan(v) else v for v in values) + + sanitized = [_row_values(row) for row in data] buffer = FilteredStringIO(["\x00"]) writer = csv.writer(buffer, quoting=csv.QUOTE_MINIMAL) - writer.writerows(data) + writer.writerows(sanitized) buffer.seek(0) # List should have same column names as destination table, though not all columns in table are required @@ -362,7 +368,7 @@ def _init_app_db_connection( password_override: str | None = None, user_type: UserType = "normal", use_raw: bool = False, -) -> Connection | _ConnectionFairy: +) -> Connection | PoolProxiedConnection: database_name = "postgres" if user_type == "database_admin" else get_tg_db() is_admin = user_type == "database_admin" or user_type == "schema_admin" @@ -399,7 +405,7 @@ def _init_app_db_connection( try: schema_name = "public" if is_admin else get_tg_schema() if use_raw: - connection: _ConnectionFairy = engine.raw_connection() + connection: PoolProxiedConnection = engine.raw_connection() with connection.cursor() as cursor: cursor.execute( "SET SEARCH_PATH = %(schema_name)s", diff --git a/testgen/common/database/flavor/mssql_flavor_service.py b/testgen/common/database/flavor/mssql_flavor_service.py index 70ee3d11..570c8b5c 100644 --- a/testgen/common/database/flavor/mssql_flavor_service.py +++ b/testgen/common/database/flavor/mssql_flavor_service.py @@ -1,5 +1,3 @@ -from urllib.parse import quote_plus - from sqlalchemy.engine import URL from testgen import settings @@ -17,7 +15,7 @@ def get_connection_string_from_fields(self, params: ResolvedConnectionParams) -> connection_url = URL.create( self.url_scheme, username=params.username, - password=quote_plus(params.password or ""), + password=params.password or "", host=params.host, port=int(params.port or 1443), database=params.dbname, diff --git a/testgen/common/database/flavor/postgresql_flavor_service.py b/testgen/common/database/flavor/postgresql_flavor_service.py index 99f968c6..011ab05a 100644 --- a/testgen/common/database/flavor/postgresql_flavor_service.py +++ b/testgen/common/database/flavor/postgresql_flavor_service.py @@ -7,6 +7,7 @@ class PostgresqlFlavorService(RedshiftFlavorService): escaped_underscore = "\\_" + url_scheme = "postgresql" def get_connection_string_from_fields(self, params: ResolvedConnectionParams) -> str: if params.host.startswith("/"): diff --git a/testgen/common/database/flavor/redshift_flavor_service.py b/testgen/common/database/flavor/redshift_flavor_service.py index 3b6c6e6a..77459a0f 100644 --- a/testgen/common/database/flavor/redshift_flavor_service.py +++ b/testgen/common/database/flavor/redshift_flavor_service.py @@ -1,13 +1,39 @@ from urllib.parse import quote_plus -from testgen.common.database.flavor.flavor_service import FlavorService, ResolvedConnectionParams +from sqlalchemy.dialects import registry as _dialect_registry +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.engine import Engine +from sqlalchemy.engine import create_engine as sqlalchemy_create_engine + +from testgen.common.database.flavor.flavor_service import ( + ConnectionParams, + FlavorService, + ResolvedConnectionParams, + resolve_connection_params, +) + + +class _RedshiftDialect(PGDialect_psycopg2): + """PostgreSQL dialect patched for Redshift compatibility. + + Redshift doesn't support ``standard_conforming_strings``, which SA 2.0's + PostgreSQL dialect queries during ``initialize()``. This subclass stubs out + the check so connections succeed. + """ + name = "redshift_pg" + + def _set_backslash_escapes(self, connection): + self._backslash_escapes = False + + +# Register so ``redshift_pg://`` URLs resolve to this dialect +_dialect_registry.register("redshift_pg", __name__, "_RedshiftDialect") class RedshiftFlavorService(FlavorService): escaped_underscore = "\\\\_" - url_scheme = "postgresql" + url_scheme = "redshift_pg" def get_connection_string_from_fields(self, params: ResolvedConnectionParams) -> str: - # STANDARD FORMAT: strConnect = 'flavor://username:password@host:port/database' return f"{self.url_scheme}://{params.username}:{quote_plus(params.password)}@{params.host}:{params.port}/{params.dbname}" diff --git a/testgen/common/models/__init__.py b/testgen/common/models/__init__.py index 898090dd..21dcc448 100644 --- a/testgen/common/models/__init__.py +++ b/testgen/common/models/__init__.py @@ -5,7 +5,8 @@ import urllib.parse from sqlalchemy import create_engine -from sqlalchemy.orm import DeclarativeBase, Session as SQLAlchemySession, sessionmaker +from sqlalchemy.orm import DeclarativeBase, sessionmaker +from sqlalchemy.orm import Session as SQLAlchemySession from testgen import settings diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 3d7560de..5f1ad5bb 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -12,9 +12,15 @@ from testgen.common.models import Base, get_current_session from testgen.utils import is_uuid4, make_json_safe +def _hash_clause(x): + # Don't use literal_binds=True — SA 2.0 can't render UUID POSTCOMPILE IN-lists + # that way and raises CompileError when Streamlit hashes cached args. + compiled = x.compile() + return f"{compiled}|{compiled.params}" + ENTITY_HASH_FUNCS = { - BinaryExpression: lambda x: str(x.compile(compile_kwargs={"literal_binds": True})), - BooleanClauseList: lambda x: str(x.compile(compile_kwargs={"literal_binds": True})), + BinaryExpression: _hash_clause, + BooleanClauseList: _hash_clause, tuple: lambda x: [str(y) for y in x], } @@ -68,13 +74,13 @@ def _get_columns( select_columns = [ getattr(cls, col, None) or getattr(join_target, col) if isinstance(col, str) else col for col in columns ] - query = select(select_columns).join(join_target, join_clause) + query = select(*select_columns).join(join_target, join_clause) else: select_columns = [getattr(cls, col) if isinstance(col, str) else col for col in columns] - query = select(select_columns) + query = select(*select_columns) query = query.where(get_by_column == identifier) - return get_current_session().execute(query).first() + return get_current_session().execute(query).mappings().first() @classmethod @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) @@ -100,14 +106,14 @@ def _select_columns_where( select_columns = [ getattr(cls, col, None) or getattr(join_target, col) if isinstance(col, str) else col for col in columns ] - query = select(select_columns).join(join_target, join_clause) + query = select(*select_columns).join(join_target, join_clause) else: select_columns = [getattr(cls, col) if isinstance(col, str) else col for col in columns] - query = select(select_columns) + query = select(*select_columns) order_by = order_by or cls._default_order_by query = query.where(*clauses).order_by(*order_by) - return get_current_session().execute(query).all() + return get_current_session().execute(query).mappings().all() @classmethod def has_running_process(cls, ids: list[str]) -> bool: diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 41ae7e16..e2c15f41 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -113,9 +113,9 @@ def get_minimal(cls, run_id: str | UUID) -> ProfilingRunMinimal | None: return None query = ( - select(cls._minimal_columns).join(TableGroup, cls.table_groups_id == TableGroup.id).where(cls.id == run_id) + select(*cls._minimal_columns).join(TableGroup, cls.table_groups_id == TableGroup.id).where(cls.id == run_id) ) - result = get_current_session().execute(query).first() + result = get_current_session().execute(query).mappings().first() return ProfilingRunMinimal(**result) if result else None @classmethod @@ -126,7 +126,7 @@ def get_latest_run(cls, project_code: str) -> LatestProfilingRun | None: .order_by(desc(ProfilingRun.profiling_starttime)) .limit(1) ) - result = get_current_session().execute(query).first() + result = get_current_session().execute(query).mappings().first() if result: return LatestProfilingRun(str(result["id"]), result["profiling_starttime"]) return None @@ -137,12 +137,12 @@ def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[ProfilingRunMinimal]: query = ( - select(cls._minimal_columns) + select(*cls._minimal_columns) .join(TableGroup, cls.table_groups_id == TableGroup.id) .where(*clauses) .order_by(*order_by) ) - results = get_current_session().execute(query).all() + results = get_current_session().execute(query).mappings().all() return [ProfilingRunMinimal(**row) for row in results] @classmethod diff --git a/testgen/common/models/project.py b/testgen/common/models/project.py index eedfb13f..f0bb116b 100644 --- a/testgen/common/models/project.py +++ b/testgen/common/models/project.py @@ -89,7 +89,7 @@ def get_summary(cls, project_code: str) -> ProjectSummary | None: db_session = get_current_session() result = db_session.execute(text(query), {"project_code": project_code}).first() - return ProjectSummary(**result, project_code=project_code) if result else None + return ProjectSummary(**result._mapping, project_code=project_code) if result else None @classmethod def is_in_use(cls, ids: list[str]) -> bool: diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index 788ee00b..6eee93c3 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -76,26 +76,29 @@ class ScoreDefinition(Base): cde_score: bool = Column(Boolean, default=False, nullable=False) category: ScoreCategory | None = Column(Enum(ScoreCategory), nullable=True) - criteria: ScoreDefinitionCriteria = relationship( + # Note: avoid `Mapped[...]`-style or `Iterable[...]` annotations on these — + # with DeclarativeBase + `__allow_unmapped__=True`, they confuse SA's + # uselist inference and cause `self.results` to be treated as a scalar. + criteria = relationship( "ScoreDefinitionCriteria", cascade="all, delete-orphan", lazy="select", uselist=False, single_parent=True, ) - results: Iterable[ScoreDefinitionResult] = relationship( + results = relationship( "ScoreDefinitionResult", cascade="all, delete-orphan", order_by="ScoreDefinitionResult.category", lazy="joined", ) - breakdown: Iterable[ScoreDefinitionBreakdownItem] = relationship( + breakdown = relationship( "ScoreDefinitionBreakdownItem", cascade="all, delete-orphan", order_by="ScoreDefinitionBreakdownItem.impact.desc()", lazy="select", ) - history: Iterable[ScoreDefinitionResultHistoryEntry] = relationship( + history = relationship( "ScoreDefinitionResultHistoryEntry", order_by="ScoreDefinitionResultHistoryEntry.last_run_time.asc()", cascade="all, delete-orphan", @@ -241,10 +244,10 @@ def as_score_card(self) -> ScoreCard: filters = " AND ".join(self._get_raw_query_filters()) overall_scores = get_current_session().execute( - read_template_sql_file( + text(read_template_sql_file( overall_score_query_template_file, sub_directory="score_cards", - ).replace("{filters}", filters) + ).replace("{filters}", filters)) ).mappings().first() or {} categories_scores = [] @@ -252,10 +255,10 @@ def as_score_card(self) -> ScoreCard: categories_scores = [ dict(result) for result in get_current_session().execute( - read_template_sql_file( + text(read_template_sql_file( categories_query_template_file, sub_directory="score_cards", - ).replace("{category}", category.value).replace("{filters}", filters) + ).replace("{category}", category.value).replace("{filters}", filters)) ).mappings().all() ] @@ -359,7 +362,7 @@ def get_score_card_breakdown( .replace("{records_count_filters}", records_count_filters) .replace("{non_null_columns}", ", ".join(non_null_columns)) ) - results = get_current_session().execute(query).mappings().all() + results = get_current_session().execute(text(query)).mappings().all() return [dict(row) for row in results] @@ -499,7 +502,7 @@ class ScoreDefinitionCriteria(Base): definition_id: str = Column(postgresql.UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE")) operand: Literal["AND", "OR"] = Column(String, nullable=False, default="AND") group_by_field: bool = Column(Boolean, nullable=False, default=True) - filters: list[ScoreDefinitionFilter] = relationship( + filters = relationship( "ScoreDefinitionFilter", cascade="all, delete-orphan", lazy="joined", @@ -578,7 +581,7 @@ class ScoreDefinitionFilter(Base): nullable=True, default=None, ) - next_filter: ScoreDefinitionFilter = relationship( + next_filter = relationship( "ScoreDefinitionFilter", cascade="all, delete-orphan", lazy="joined", @@ -681,7 +684,7 @@ class ScoreDefinitionResultHistoryEntry(Base): score: float = Column(Float, nullable=True) last_run_time: datetime = Column(DateTime(timezone=False), nullable=False, primary_key=True) - definition: ScoreDefinition = relationship("ScoreDefinition", back_populates="history") + definition = relationship("ScoreDefinition", back_populates="history") def add_as_cutoff(self): """ diff --git a/testgen/common/models/settings.py b/testgen/common/models/settings.py index f98b1565..b66b7c23 100644 --- a/testgen/common/models/settings.py +++ b/testgen/common/models/settings.py @@ -1,7 +1,7 @@ from typing import Any -from sqlalchemy import Column, String -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy import Column, String, select +from sqlalchemy.dialects.postgresql import JSONB, insert as pg_insert from testgen.common.models import Base, get_current_session @@ -21,9 +21,9 @@ class PersistedSetting(Base): @classmethod def get(cls, key: str, default=NO_DEFAULT) -> Any: # This caches all the settings in the session, so it hits the database only once - get_current_session().query(cls).all() + get_current_session().execute(select(cls)).scalars().all() - if ps := get_current_session().query(cls).filter_by(key=key).first(): + if ps := get_current_session().execute(select(cls).filter_by(key=key)).scalars().first(): return ps.value elif default is NO_DEFAULT: raise SettingNotFound(f"Setting '{key}' not found") @@ -32,11 +32,12 @@ def get(cls, key: str, default=NO_DEFAULT) -> Any: @classmethod def set(cls, key: str, value: Any): + # Atomic upsert: avoids the check-then-insert race that bites when multiple + # Streamlit reruns or sibling processes (UI + scheduler) touch the same key. session = get_current_session() - if ps := session.query(cls).filter_by(key=key).first(): - ps.value = value - else: - session.add(cls(key=key, value=value)) + stmt = pg_insert(cls).values(key=key, value=value) + stmt = stmt.on_conflict_do_update(index_elements=["key"], set_={"value": stmt.excluded.value}) + session.execute(stmt) session.flush() def __repr__(self): diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index e9e2651c..2748777b 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -383,7 +383,7 @@ def copy( select_columns.extend(other_columns) query = insert(cls).from_select( - [*modified_columns, *other_columns], select(select_columns).where(cls.id.in_(test_definition_ids)) + [*modified_columns, *other_columns], select(*select_columns).where(cls.id.in_(test_definition_ids)) ) db_session = get_current_session() db_session.execute(query) diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index 1517bb4e..053d9cdc 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -138,8 +138,8 @@ def get_minimal(cls, run_id: str | UUID) -> TestRunMinimal | None: if not is_uuid4(run_id): return None - query = select(cls._minimal_columns).join(TestSuite).where(cls.id == run_id) - result = get_current_session().execute(query).first() + query = select(*cls._minimal_columns).join(TestSuite).where(cls.id == run_id) + result = get_current_session().execute(query).mappings().first() return TestRunMinimal(**result) if result else None @classmethod @@ -151,7 +151,7 @@ def get_latest_run(cls, project_code: str) -> LatestTestRun | None: .order_by(desc(TestRun.test_starttime)) .limit(1) ) - result = get_current_session().execute(query).first() + result = get_current_session().execute(query).mappings().first() if result: return LatestTestRun(str(result["id"]), result["test_starttime"]) return None @@ -330,7 +330,7 @@ def get_monitoring_summary(self, table_name: str | None = None) -> TestRunMonito .group_by(*group_by) ) - return TestRunMonitorSummary(**get_current_session().execute(query).first()) + return TestRunMonitorSummary(**get_current_session().execute(query).mappings().first()) @classmethod def has_running_process(cls, ids: list[str]) -> bool: diff --git a/testgen/common/notifications/test_run.py b/testgen/common/notifications/test_run.py index 337e43dd..aa309263 100644 --- a/testgen/common/notifications/test_run.py +++ b/testgen/common/notifications/test_run.py @@ -319,7 +319,7 @@ def send_test_run_notifications(test_run: TestRun, result_list_ct=20, result_sta .limit(result_count_by_status[status]) ) - result_list_by_status[status] = [{**r} for r in get_current_session().execute(query)] + result_list_by_status[status] = [{**r._mapping} for r in get_current_session().execute(query)] tr_summary, = TestRun.select_summary(test_run_ids=[test_run.id]) diff --git a/testgen/common/process_service.py b/testgen/common/process_service.py index b37460ee..a7917649 100644 --- a/testgen/common/process_service.py +++ b/testgen/common/process_service.py @@ -13,35 +13,42 @@ def get_current_process_id(): def kill_profile_run(process_id): - keywords = ["/dk/bin/testgen", "run-profile"] - status, message = kill_process(process_id, keywords) + status, message = kill_process(process_id, subcommand="run-profile") return status, message def kill_test_run(process_id): - keywords = ["/dk/bin/testgen", "run-tests"] - status, message = kill_process(process_id, keywords) + status, message = kill_process(process_id, subcommand="run-tests") return status, message -def kill_process(process_id, keywords=None): +def _is_testgen_process(process) -> bool: + """A process is ours if any cmdline argument references the testgen entry point. + + The executable name varies by platform (e.g. macOS reports "Python" for the + framework binary, Linux "python3.13", Docker "testgen") so we match on the + command line instead. + """ + return any("testgen" in arg.lower() for arg in process.cmdline()) + + +def kill_process(process_id, subcommand: str | None = None): if settings.IS_DEBUG: msg = "Cannot kill processes in debug mode (threads are used instead of new process)" LOG.warn(msg) return False, msg try: process = psutil.Process(process_id) - if process.name().lower() not in ["testgen", "python3"]: - message = f"The process was not killed because the process_id {process_id} is not a testgen process. Details: {process.name()}" + cmdline = process.cmdline() + if not _is_testgen_process(process): + message = f"The process was not killed because the process_id {process_id} is not a testgen process. Details: {process.name()} {cmdline}" LOG.error(f"kill_process: {message}") return False, message - if keywords: - for keyword in keywords: - if keyword.lower() not in process.cmdline(): - message = f"The process was not killed because the keyword {keyword} was not found. Details: {process.cmdline()}" - LOG.error(f"kill_process: {message}") - return False, message + if subcommand and subcommand not in cmdline: + message = f"The process was not killed because the subcommand {subcommand} was not found. Details: {cmdline}" + LOG.error(f"kill_process: {message}") + return False, message process.terminate() process.wait(timeout=10) diff --git a/testgen/common/read_file.py b/testgen/common/read_file.py index 8e49fb8d..ada6a86a 100644 --- a/testgen/common/read_file.py +++ b/testgen/common/read_file.py @@ -4,7 +4,7 @@ import re from collections.abc import Generator from functools import cache -from importlib.abc import Traversable +from importlib.resources.abc import Traversable from importlib.resources import as_file, files import yaml diff --git a/testgen/ui/components/frontend/js/components/score_breakdown.js b/testgen/ui/components/frontend/js/components/score_breakdown.js index acd2ffe1..fe3b2c53 100644 --- a/testgen/ui/components/frontend/js/components/score_breakdown.js +++ b/testgen/ui/components/frontend/js/components/score_breakdown.js @@ -156,10 +156,14 @@ const IssueCountCell = (value, row, score, category, scoreType, onViewDetails) = drilldown = `${row.table_groups_id}.${row.table_name}.${row.column_name}`; } + // Hide View for rows where the grouping value is null/empty — drilldown filtering + // needs a non-empty value on the backend and router, so the link would dead-end. + const canDrillDown = value && drilldown && onViewDetails; + return div( { class: 'flex-row', style: `flex: ${BREAKDOWN_COLUMNS_SIZES.issue_ct}`, 'data-testid': 'score-breakdown-cell' }, span({ class: 'mr-2', style: 'min-width: 40px;' }, value || '-'), - (value && onViewDetails) + canDrillDown ? div( { class: 'flex-row clickable', diff --git a/testgen/ui/pdf/dataframe_table.py b/testgen/ui/pdf/dataframe_table.py index 2516444e..9f4966e8 100644 --- a/testgen/ui/pdf/dataframe_table.py +++ b/testgen/ui/pdf/dataframe_table.py @@ -1,8 +1,5 @@ from collections.abc import Iterable -from math import nan - import pandas -from numpy import NaN from pandas.core.dtypes.common import is_numeric_dtype from reportlab.lib import colors, enums from reportlab.lib.styles import ParagraphStyle @@ -271,7 +268,7 @@ def _convert_col_values(self, col): def _convert_value(value): if isinstance(value, Paragraph): return value - elif value in (None, NaN, nan): + elif pandas.isna(value): return self.null_para else: return Paragraph(str(value), para_style) diff --git a/testgen/ui/services/database_service.py b/testgen/ui/services/database_service.py index 8877a423..98d6b251 100644 --- a/testgen/ui/services/database_service.py +++ b/testgen/ui/services/database_service.py @@ -12,7 +12,7 @@ from typing import Any from sqlalchemy import text -from sqlalchemy.engine import Row, RowMapping +from sqlalchemy.engine import RowMapping from sqlalchemy.engine.cursor import CursorResult from testgen.common.database.database_service import get_flavor_service @@ -54,7 +54,7 @@ def fetch_one_from_db(query: str, params: dict | None = None) -> RowMapping | No return result._mapping if result else None -def fetch_from_target_db(connection: Connection, query: str, params: dict | None = None) -> list[Row]: +def fetch_from_target_db(connection: Connection, query: str, params: dict | None = None) -> list[RowMapping]: connection_params = connection.to_dict() flavor_service = get_flavor_service(connection.sql_flavor) resolved = resolve_connection_params(connection_params) @@ -64,4 +64,4 @@ def fetch_from_target_db(connection: Connection, query: str, params: dict | None for pre_query, pre_params in flavor_service.get_pre_connection_queries(resolved): conn.execute(text(pre_query), pre_params) cursor: CursorResult = conn.execute(text(query), params) - return cursor.fetchall() + return cursor.mappings().fetchall() diff --git a/testgen/ui/static/js/components/score_breakdown.js b/testgen/ui/static/js/components/score_breakdown.js index acd2ffe1..fe3b2c53 100644 --- a/testgen/ui/static/js/components/score_breakdown.js +++ b/testgen/ui/static/js/components/score_breakdown.js @@ -156,10 +156,14 @@ const IssueCountCell = (value, row, score, category, scoreType, onViewDetails) = drilldown = `${row.table_groups_id}.${row.table_name}.${row.column_name}`; } + // Hide View for rows where the grouping value is null/empty — drilldown filtering + // needs a non-empty value on the backend and router, so the link would dead-end. + const canDrillDown = value && drilldown && onViewDetails; + return div( { class: 'flex-row', style: `flex: ${BREAKDOWN_COLUMNS_SIZES.issue_ct}`, 'data-testid': 'score-breakdown-cell' }, span({ class: 'mr-2', style: 'min-width: 40px;' }, value || '-'), - (value && onViewDetails) + canDrillDown ? div( { class: 'flex-row clickable', diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index c0089fca..efbd28a6 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -105,7 +105,13 @@ def on_save_connection_clicked(updated_connection): if len(url_parts) > 1: updated_connection["url"] = url_parts[1] - if updated_connection.get("connect_by_key"): + # Databricks OAuth sets connect_by_key but stores the Client Secret in project_pw_encrypted, + # so it follows the password path rather than the private-key path. + uses_private_key = ( + updated_connection.get("connect_by_key") + and updated_connection.get("sql_flavor_code") != "databricks" + ) + if uses_private_key: updated_connection["project_pw_encrypted"] = "" if is_pristine(updated_connection.get("private_key_passphrase")): del updated_connection["private_key_passphrase"] @@ -242,7 +248,7 @@ def test_connection(self, connection: Connection) -> "ConnectionStatus": try: flavor_service = get_flavor_service(connection.sql_flavor) results = db.fetch_from_target_db(connection, flavor_service.test_query) - connection_successful = len(results) == 1 and results[0][0] == 1 + connection_successful = len(results) == 1 and next(iter(results[0].values())) == 1 if not connection_successful: return ConnectionStatus(message="Error completing a query to the database server.", successful=False) diff --git a/tests/unit/common/notifications/test_test_run_notifications.py b/tests/unit/common/notifications/test_test_run_notifications.py index 06cd75f9..4855dc58 100644 --- a/tests/unit/common/notifications/test_test_run_notifications.py +++ b/tests/unit/common/notifications/test_test_run_notifications.py @@ -141,8 +141,14 @@ def test_send_test_run_notification( error_ct=error_ct, ) + # SA 2.0 Row objects expose ._mapping; mock them accordingly + def _make_row(): + r = Mock() + r._mapping = {} + return r + db_session_mock.execute.side_effect = [ - [{} for _ in range(ct)] + [_make_row() for _ in range(ct)] for ct in (failed_expected, warning_expected, error_expected) if ct > 0 ]