From 8f56d78f50405c8656a28289513efcb48d8202dc Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 11 Apr 2026 12:27:01 -0400 Subject: [PATCH] improve main.py sandbox-mode test coverage moving some shared utilities to test/utils.py --- test/pytests/test_main.py | 53 ++++++++++++++++++++++++++++ test/pytests/test_main_regression.py | 20 +---------- test/utils.py | 19 ++++++++++ 3 files changed, 73 insertions(+), 19 deletions(-) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 7411e21e..3f511851 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -13,6 +13,7 @@ import click from click.testing import CliRunner +import pymysql from pymysql.err import OperationalError import pytest @@ -38,7 +39,9 @@ TEMPFILE_PREFIX, USER, DummyFormatter, + DummyLogger, FakeCursorBase, + RecordingSQLExecute, ReusableLock, call_click_entrypoint_direct, dbtest, @@ -2365,3 +2368,53 @@ def test_get_last_query_returns_latest_query() -> None: cli.query_history = [main.Query('select 1', True, False)] assert main.MyCli.get_last_query(cli) == 'select 1' + + +def test_connect_reports_expired_password_login_error(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.config = {'connection': {}, 'main': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class ExpiredPasswordSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [pymysql.OperationalError(main.ER_MUST_CHANGE_PASSWORD_LOGIN, 'must change password')] + + monkeypatch.setattr(main, 'SQLExecute', ExpiredPasswordSQLExecute) + + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db', port=3307) + + assert any('password has expired' in message for message in echo_calls) + + +def test_connect_sets_cli_sandbox_mode_when_sqlexecute_enters_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.config = {'connection': {}, 'main': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class SandboxSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.sandbox_mode = True + + monkeypatch.setattr(main, 'SQLExecute', SandboxSQLExecute) + + main.MyCli.connect(cli, host='db', port=3307) + + assert cli.sandbox_mode is True + assert any('password has expired' in message for message in echo_calls) diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index bdb106c7..f4dfc62c 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -38,6 +38,7 @@ DummyFormatter, DummyLogger, FakeCursorBase, + RecordingSQLExecute, call_click_entrypoint_direct, make_bare_mycli, make_dummy_mycli_class, @@ -60,25 +61,6 @@ def as_bool(self, key: str) -> bool: return str(self[key]).lower() == 'true' -class RecordingSQLExecute: - calls: list[dict[str, Any]] = [] - side_effects: list[Any] = [] - - def __init__(self, **kwargs: Any) -> None: - type(self).calls.append(dict(kwargs)) - if type(self).side_effects: - effect = type(self).side_effects.pop(0) - if isinstance(effect, BaseException): - raise effect - if callable(effect): - effect(kwargs) - self.kwargs = kwargs - self.dbname = kwargs.get('database') - self.user = kwargs.get('user') - self.conn = kwargs.get('conn') - self.sandbox_mode = False - - class ToggleBool: def __init__(self, values: list[bool]) -> None: self.values = values diff --git a/test/utils.py b/test/utils.py index 7c43af5c..1d01ac33 100644 --- a/test/utils.py +++ b/test/utils.py @@ -101,6 +101,25 @@ def __iter__(self) -> Iterator[tuple[Any, ...]]: return iter(self._rows) +class RecordingSQLExecute: + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + def __init__(self, **kwargs: Any) -> None: + type(self).calls.append(dict(kwargs)) + if type(self).side_effects: + effect = type(self).side_effects.pop(0) + if isinstance(effect, BaseException): + raise effect + if callable(effect): + effect(kwargs) + self.kwargs = kwargs + self.dbname = kwargs.get('database') + self.user = kwargs.get('user') + self.conn = kwargs.get('conn') + self.sandbox_mode = False + + def make_bare_mycli() -> Any: cli = object.__new__(main.MyCli) cli.logger = cast(Any, DummyLogger())