From 0910b3d0f7610581e5c40f8b210d2b57e98840c1 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 16 Mar 2026 20:35:21 +0200 Subject: [PATCH 1/9] feat: implement PR commit message retrieval and validation in commit-check --- main.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 7491710..71ab225 100755 --- a/main.py +++ b/main.py @@ -36,6 +36,30 @@ def log_env_vars(): print(f"PR_COMMENTS = {PR_COMMENTS}\n") +def get_pr_commit_messages() -> list[str]: + """Get all commit messages for the current PR (pull_request event only). + + In a pull_request event, actions/checkout checks out a synthetic merge + commit (HEAD = merge of PR branch into base). HEAD^1 is the base branch + tip, HEAD^2 is the PR branch tip. So HEAD^1..HEAD^2 gives all PR commits. + """ + if os.getenv("GITHUB_EVENT_NAME", "") != "pull_request": + return [] + try: + result = subprocess.run( + ["git", "log", "--pretty=format:%B%x00", "HEAD^1..HEAD^2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return [m.strip() for m in result.stdout.split("\x00") if m.strip()] + except Exception: + pass + return [] + + def run_commit_check() -> int: """Runs the commit-check command and logs the result.""" args = [ @@ -58,9 +82,39 @@ def run_commit_check() -> int: if value == "true" ] - command = ["commit-check"] + args - print(" ".join(command)) + total_rc = 0 with open("result.txt", "w") as result_file: + if MESSAGE == "true": + pr_messages = get_pr_commit_messages() + if pr_messages: + # In PR context: check each commit message individually to avoid + # only validating the synthetic merge commit at HEAD. + for msg in pr_messages: + result = subprocess.run( + ["commit-check", "--message"], + input=msg, + stdout=result_file, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + total_rc += result.returncode + + # Run non-message checks (branch, author) once + other_args = [a for a in args if a != "--message"] + if other_args: + command = ["commit-check"] + other_args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + total_rc += result.returncode + + return total_rc + + # Non-PR context or message disabled: run all checks at once + command = ["commit-check"] + args + print(" ".join(command)) result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) From 8a8c1f6422eb04506a46f04403b23f8016226b5a Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 16 Mar 2026 20:41:57 +0200 Subject: [PATCH 2/9] fix: correct variable names in subprocess calls for clarity --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 71ab225..846518c 100755 --- a/main.py +++ b/main.py @@ -105,20 +105,20 @@ def run_commit_check() -> int: if other_args: command = ["commit-check"] + other_args print(" ".join(command)) - result = subprocess.run( + other_result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) - total_rc += result.returncode + total_rc += other_result.returncode return total_rc # Non-PR context or message disabled: run all checks at once command = ["commit-check"] + args print(" ".join(command)) - result = subprocess.run( + default_result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) - return result.returncode + return default_result.returncode def read_result_file() -> str | None: From 235d3975180c298576cd1927b77d44c332554765 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 01:26:10 +0200 Subject: [PATCH 3/9] fix: get original commit message content in get_pr_commit_messages() (#190) --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 846518c..aa4bd88 100755 --- a/main.py +++ b/main.py @@ -54,7 +54,7 @@ def get_pr_commit_messages() -> list[str]: check=False, ) if result.returncode == 0 and result.stdout: - return [m.strip() for m in result.stdout.split("\x00") if m.strip()] + return [m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n")] except Exception: pass return [] From 44e900c44acb56f8fdd6eee1039f051d4b4419b2 Mon Sep 17 00:00:00 2001 From: Xianpeng Shen Date: Tue, 17 Mar 2026 01:38:30 +0200 Subject: [PATCH 4/9] chore: Update commit-check version to 2.4.3 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c1cf58e..2473c51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Install commit-check CLI # For details please see: https://github.com/commit-check/commit-check -commit-check==2.4.2 +commit-check==2.4.3 # Interact with the GitHub API. PyGithub==2.8.1 From cb56efe463648caabf3a64bff3599a9f8ae7d1e4 Mon Sep 17 00:00:00 2001 From: Xianpeng Shen Date: Tue, 17 Mar 2026 02:06:08 +0200 Subject: [PATCH 5/9] feat: Enable autofix for pull requests in pre-commit config --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 571d5ee..4eee43c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ # https://pre-commit.com/ ci: + autofix_prs: true autofix_commit_msg: 'ci: auto fixes from pre-commit.com hooks' autoupdate_commit_msg: 'ci: pre-commit autoupdate' From 85f23c7b24cb0b7adf4a9e8872b0e3ccaa66de0f Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:23:46 +0200 Subject: [PATCH 6/9] feat: refactor commit-check logic and add unit tests for new functionality --- main.py | 102 +++++----- main_test.py | 538 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 593 insertions(+), 47 deletions(-) create mode 100644 main_test.py diff --git a/main.py b/main.py index aa4bd88..bbcd7f8 100755 --- a/main.py +++ b/main.py @@ -60,28 +60,59 @@ def get_pr_commit_messages() -> list[str]: return [] -def run_commit_check() -> int: - """Runs the commit-check command and logs the result.""" - args = [ - "--message", - "--branch", - "--author-name", - "--author-email", - ] - args = [ - arg - for arg, value in zip( - args, - [ - MESSAGE, - BRANCH, - AUTHOR_NAME, - AUTHOR_EMAIL, - ], +def build_check_args( + message: str, branch: str, author_name: str, author_email: str +) -> list[str]: + """Maps 'true'/'false' flag values to CLI argument list.""" + flags = ["--message", "--branch", "--author-name", "--author-email"] + values = [message, branch, author_name, author_email] + return [flag for flag, value in zip(flags, values) if value == "true"] + + +def run_pr_message_checks(pr_messages: list[str], result_file) -> int: # type: ignore[type-arg] + """Checks each PR commit message individually via commit-check --message. + + Returns cumulative returncode across all messages. + """ + total_rc = 0 + for msg in pr_messages: + result = subprocess.run( + ["commit-check", "--message"], + input=msg, + stdout=result_file, + stderr=subprocess.PIPE, + text=True, + check=False, ) - if value == "true" - ] + total_rc += result.returncode + return total_rc + +def run_other_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] + """Runs non-message checks (branch, author) once. Returns 0 if args is empty.""" + if not args: + return 0 + command = ["commit-check"] + args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + return result.returncode + + +def run_default_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] + """Runs all checks at once (non-PR context or message disabled).""" + command = ["commit-check"] + args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + return result.returncode + + +def run_commit_check() -> int: + """Runs the commit-check command and logs the result.""" + args = build_check_args(MESSAGE, BRANCH, AUTHOR_NAME, AUTHOR_EMAIL) total_rc = 0 with open("result.txt", "w") as result_file: if MESSAGE == "true": @@ -89,36 +120,13 @@ def run_commit_check() -> int: if pr_messages: # In PR context: check each commit message individually to avoid # only validating the synthetic merge commit at HEAD. - for msg in pr_messages: - result = subprocess.run( - ["commit-check", "--message"], - input=msg, - stdout=result_file, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - total_rc += result.returncode - - # Run non-message checks (branch, author) once + total_rc += run_pr_message_checks(pr_messages, result_file) other_args = [a for a in args if a != "--message"] - if other_args: - command = ["commit-check"] + other_args - print(" ".join(command)) - other_result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False - ) - total_rc += other_result.returncode - + total_rc += run_other_checks(other_args, result_file) return total_rc - # Non-PR context or message disabled: run all checks at once - command = ["commit-check"] + args - print(" ".join(command)) - default_result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False - ) - return default_result.returncode + total_rc += run_default_checks(args, result_file) + return total_rc def read_result_file() -> str | None: diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..85a108c --- /dev/null +++ b/main_test.py @@ -0,0 +1,538 @@ +"""Unit tests for main.py""" +import io +import json +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +# GITHUB_STEP_SUMMARY is accessed via os.environ[] (not getenv) at import time, +# so we must set it before importing main. +os.environ.setdefault("GITHUB_STEP_SUMMARY", "/tmp/step_summary.txt") + +import main # noqa: E402 + + +class TestBuildCheckArgs(unittest.TestCase): + def test_all_true(self): + result = main.build_check_args("true", "true", "true", "true") + self.assertEqual(result, ["--message", "--branch", "--author-name", "--author-email"]) + + def test_all_false(self): + result = main.build_check_args("false", "false", "false", "false") + self.assertEqual(result, []) + + def test_message_only(self): + result = main.build_check_args("true", "false", "false", "false") + self.assertEqual(result, ["--message"]) + + def test_branch_only(self): + result = main.build_check_args("false", "true", "false", "false") + self.assertEqual(result, ["--branch"]) + + def test_author_name_and_email(self): + result = main.build_check_args("false", "false", "true", "true") + self.assertEqual(result, ["--author-name", "--author-email"]) + + def test_message_and_branch(self): + result = main.build_check_args("true", "true", "false", "false") + self.assertEqual(result, ["--message", "--branch"]) + + +class TestRunPrMessageChecks(unittest.TestCase): + def _make_file(self): + return io.StringIO() + + def test_single_message_pass(self): + mock_result = MagicMock() + mock_result.returncode = 0 + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_pr_message_checks(["fix: something"], self._make_file()) + self.assertEqual(rc, 0) + mock_run.assert_called_once() + call_kwargs = mock_run.call_args + self.assertIn("--message", call_kwargs[0][0]) + self.assertEqual(call_kwargs[1]["input"], "fix: something") + + def test_single_message_fail(self): + mock_result = MagicMock() + mock_result.returncode = 1 + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_pr_message_checks(["bad commit"], self._make_file()) + self.assertEqual(rc, 1) + + def test_multiple_messages_partial_failure(self): + results = [MagicMock(returncode=0), MagicMock(returncode=1), MagicMock(returncode=0)] + with patch("main.subprocess.run", side_effect=results): + rc = main.run_pr_message_checks(["ok", "bad", "ok"], self._make_file()) + self.assertEqual(rc, 1) + + def test_multiple_messages_all_fail(self): + results = [MagicMock(returncode=1), MagicMock(returncode=1)] + with patch("main.subprocess.run", side_effect=results): + rc = main.run_pr_message_checks(["bad1", "bad2"], self._make_file()) + self.assertEqual(rc, 2) + + def test_empty_list(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_pr_message_checks([], self._make_file()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + +class TestRunOtherChecks(unittest.TestCase): + def test_empty_args_returns_zero(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_other_checks([], io.StringIO()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + def test_with_args_calls_subprocess(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_other_checks(["--branch"], io.StringIO()) + self.assertEqual(rc, 0) + called_cmd = mock_run.call_args[0][0] + self.assertEqual(called_cmd, ["commit-check", "--branch"]) + + def test_with_args_returns_returncode(self): + mock_result = MagicMock(returncode=1) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_other_checks(["--branch", "--author-name"], io.StringIO()) + self.assertEqual(rc, 1) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_other_checks(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") + + +class TestRunDefaultChecks(unittest.TestCase): + def test_rc_zero(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_default_checks(["--message", "--branch"], io.StringIO()) + self.assertEqual(rc, 0) + + def test_rc_one(self): + mock_result = MagicMock(returncode=1) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_default_checks(["--message"], io.StringIO()) + self.assertEqual(rc, 1) + + def test_command_contains_all_args(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_default_checks(["--message", "--branch", "--author-name"], io.StringIO()) + called_cmd = mock_run.call_args[0][0] + self.assertEqual( + called_cmd, + ["commit-check", "--message", "--branch", "--author-name"], + ) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_default_checks(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") + + def test_empty_args(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_default_checks([], io.StringIO()) + called_cmd = mock_run.call_args[0][0] + self.assertEqual(called_cmd, ["commit-check"]) + + +class TestRunCommitCheck(unittest.TestCase): + def setUp(self): + # Ensure result.txt is written to a temp location + self._orig_dir = os.getcwd() + import tempfile + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_pr_path_calls_pr_message_checks(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["fix: something"]), + patch("main.run_pr_message_checks", return_value=0) as mock_pr, + patch("main.run_other_checks", return_value=0), + patch("main.run_default_checks") as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_called_once() + mock_default.assert_not_called() + self.assertEqual(rc, 0) + + def test_pr_path_rc_accumulation(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["bad msg"]), + patch("main.run_pr_message_checks", return_value=2), + patch("main.run_other_checks", return_value=1), + ): + rc = main.run_commit_check() + self.assertEqual(rc, 3) + + def test_non_pr_path_uses_default_checks(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=[]), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_default_checks", return_value=0) as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_not_called() + mock_default.assert_called_once() + self.assertEqual(rc, 0) + + def test_message_false_uses_default_checks(self): + with ( + patch("main.MESSAGE", "false"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_default_checks", return_value=0) as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_not_called() + mock_default.assert_called_once() + self.assertEqual(rc, 0) + + def test_result_txt_is_created(self): + with ( + patch("main.MESSAGE", "false"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.run_default_checks", return_value=0), + ): + main.run_commit_check() + self.assertTrue(os.path.exists(os.path.join(self._tmpdir, "result.txt"))) + + def test_other_args_excludes_message(self): + """When in PR path, run_other_checks must not receive --message.""" + captured_args = [] + + def fake_other_checks(args, result_file): + captured_args.extend(args) + return 0 + + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["fix: x"]), + patch("main.run_pr_message_checks", return_value=0), + patch("main.run_other_checks", side_effect=fake_other_checks), + ): + main.run_commit_check() + self.assertNotIn("--message", captured_args) + self.assertIn("--branch", captured_args) + + +class TestGetPrCommitMessages(unittest.TestCase): + def test_non_pr_event_returns_empty(self): + with patch.dict(os.environ, {"GITHUB_EVENT_NAME": "push"}): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_pr_event_with_commits(self): + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "fix: first\n\x00feat: second\n\x00" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + + def test_pr_event_empty_output(self): + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_git_failure_returns_empty(self): + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stdout = "" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_exception_returns_empty(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", side_effect=Exception("git not found")), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + +class TestReadResultFile(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def _write_result(self, content: str): + with open("result.txt", "w") as f: + f.write(content) + + def test_empty_file_returns_none(self): + self._write_result("") + result = main.read_result_file() + self.assertIsNone(result) + + def test_file_with_content(self): + self._write_result("some output\n") + result = main.read_result_file() + self.assertEqual(result, "some output") + + def test_ansi_codes_are_stripped(self): + self._write_result("\x1B[31mError\x1B[0m: bad commit") + result = main.read_result_file() + self.assertEqual(result, "Error: bad commit") + + def test_trailing_whitespace_stripped(self): + self._write_result("output\n\n") + result = main.read_result_file() + self.assertEqual(result, "output") + + +class TestAddJobSummary(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + # Create an empty result.txt + open("result.txt", "w").close() + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_false_skips(self): + with patch("main.JOB_SUMMARY", "false"): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + + def test_success_writes_success_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY", "true"), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value=None), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + with open(summary_path) as f: + content = f.read() + self.assertIn(main.SUCCESS_TITLE, content) + + def test_failure_writes_failure_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY", "true"), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value="bad commit message"), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 1) + with open(summary_path) as f: + content = f.read() + self.assertIn(main.FAILURE_TITLE, content) + self.assertIn("bad commit message", content) + + +class TestIsForkPr(unittest.TestCase): + def test_no_event_path(self): + with patch.dict(os.environ, {}, clear=True): + # Remove GITHUB_EVENT_PATH if present + os.environ.pop("GITHUB_EVENT_PATH", None) + result = main.is_fork_pr() + self.assertFalse(result) + + def test_same_repo_not_fork(self): + import tempfile + event = { + "pull_request": { + "head": {"repo": {"full_name": "owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(event, f) + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertFalse(result) + os.unlink(event_path) + + def test_different_repo_is_fork(self): + import tempfile + event = { + "pull_request": { + "head": {"repo": {"full_name": "fork-owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(event, f) + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertTrue(result) + os.unlink(event_path) + + def test_json_parse_failure_returns_false(self): + import tempfile + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not valid json{{{") + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertFalse(result) + os.unlink(event_path) + + +class TestLogErrorAndExit(unittest.TestCase): + def test_exits_with_specified_code(self): + with self.assertRaises(SystemExit) as ctx: + main.log_error_and_exit("# Title", None, 0) + self.assertEqual(ctx.exception.code, 0) + + def test_exits_with_nonzero_code(self): + with self.assertRaises(SystemExit) as ctx: + main.log_error_and_exit("# Title", None, 2) + self.assertEqual(ctx.exception.code, 2) + + def test_with_result_text_prints_error(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", "bad commit", 1) + mock_print.assert_called_once() + printed = mock_print.call_args[0][0] + self.assertIn("::error::", printed) + self.assertIn("bad commit", printed) + + def test_without_result_text_no_print(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", None, 1) + mock_print.assert_not_called() + + def test_empty_string_result_text_no_print(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", "", 1) + mock_print.assert_not_called() + + +class TestMain(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + open("result.txt", "w").close() + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_success_path(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=0), + patch("main.add_job_summary", return_value=0), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + def test_failure_path(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=0), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value="bad msg"), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 1) + + def test_dry_run_forces_zero(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=1), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "true"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + def test_all_subfunctions_called(self): + with ( + patch("main.log_env_vars") as mock_log, + patch("main.run_commit_check", return_value=0) as mock_run, + patch("main.add_job_summary", return_value=0) as mock_summary, + patch("main.add_pr_comments", return_value=0) as mock_comments, + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit), + ): + main.main() + mock_log.assert_called_once() + mock_run.assert_called_once() + mock_summary.assert_called_once() + mock_comments.assert_called_once() + + +if __name__ == "__main__": + unittest.main() From 8d8615e8035706434b60600e912c07f74482fa61 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:24:30 +0200 Subject: [PATCH 7/9] fix: format list comprehensions for better readability in get_pr_commit_messages() and related tests --- main.py | 4 +++- main_test.py | 22 +++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index bbcd7f8..0bd15b7 100755 --- a/main.py +++ b/main.py @@ -54,7 +54,9 @@ def get_pr_commit_messages() -> list[str]: check=False, ) if result.returncode == 0 and result.stdout: - return [m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n")] + return [ + m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n") + ] except Exception: pass return [] diff --git a/main_test.py b/main_test.py index 85a108c..b2a7525 100644 --- a/main_test.py +++ b/main_test.py @@ -1,4 +1,5 @@ """Unit tests for main.py""" + import io import json import os @@ -16,7 +17,9 @@ class TestBuildCheckArgs(unittest.TestCase): def test_all_true(self): result = main.build_check_args("true", "true", "true", "true") - self.assertEqual(result, ["--message", "--branch", "--author-name", "--author-email"]) + self.assertEqual( + result, ["--message", "--branch", "--author-name", "--author-email"] + ) def test_all_false(self): result = main.build_check_args("false", "false", "false", "false") @@ -62,7 +65,11 @@ def test_single_message_fail(self): self.assertEqual(rc, 1) def test_multiple_messages_partial_failure(self): - results = [MagicMock(returncode=0), MagicMock(returncode=1), MagicMock(returncode=0)] + results = [ + MagicMock(returncode=0), + MagicMock(returncode=1), + MagicMock(returncode=0), + ] with patch("main.subprocess.run", side_effect=results): rc = main.run_pr_message_checks(["ok", "bad", "ok"], self._make_file()) self.assertEqual(rc, 1) @@ -125,7 +132,9 @@ def test_rc_one(self): def test_command_contains_all_args(self): mock_result = MagicMock(returncode=0) with patch("main.subprocess.run", return_value=mock_result) as mock_run: - main.run_default_checks(["--message", "--branch", "--author-name"], io.StringIO()) + main.run_default_checks( + ["--message", "--branch", "--author-name"], io.StringIO() + ) called_cmd = mock_run.call_args[0][0] self.assertEqual( called_cmd, @@ -152,6 +161,7 @@ def setUp(self): # Ensure result.txt is written to a temp location self._orig_dir = os.getcwd() import tempfile + self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -300,6 +310,7 @@ def test_exception_returns_empty(self): class TestReadResultFile(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -335,6 +346,7 @@ def test_trailing_whitespace_stripped(self): class TestAddJobSummary(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -387,6 +399,7 @@ def test_no_event_path(self): def test_same_repo_not_fork(self): import tempfile + event = { "pull_request": { "head": {"repo": {"full_name": "owner/repo"}}, @@ -403,6 +416,7 @@ def test_same_repo_not_fork(self): def test_different_repo_is_fork(self): import tempfile + event = { "pull_request": { "head": {"repo": {"full_name": "fork-owner/repo"}}, @@ -419,6 +433,7 @@ def test_different_repo_is_fork(self): def test_json_parse_failure_returns_false(self): import tempfile + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write("not valid json{{{") event_path = f.name @@ -470,6 +485,7 @@ def test_empty_string_result_text_no_print(self): class TestMain(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) From 83fd88ba9d295e05f5ebf3dabaad0cd1fc90a905 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:25:32 +0200 Subject: [PATCH 8/9] fix: add args to codespell hook to ignore specific words --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4eee43c..dad2476 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,3 +26,4 @@ repos: rev: v2.3.0 hooks: - id: codespell + args: [--ignore-words-list=assertin] From 39d41e3f105a3db17447c6c34f02b02ac0ced50e Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:39:17 +0200 Subject: [PATCH 9/9] fix: update main_test.py --- main_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/main_test.py b/main_test.py index b2a7525..13670ec 100644 --- a/main_test.py +++ b/main_test.py @@ -319,7 +319,7 @@ def tearDown(self): os.chdir(self._orig_dir) def _write_result(self, content: str): - with open("result.txt", "w") as f: + with open("result.txt", "w", encoding="utf-8") as f: f.write(content) def test_empty_file_returns_none(self): @@ -351,7 +351,8 @@ def setUp(self): self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) # Create an empty result.txt - open("result.txt", "w").close() + with open("result.txt", "w", encoding="utf-8"): + pass def tearDown(self): os.chdir(self._orig_dir) @@ -370,7 +371,7 @@ def test_success_writes_success_title(self): ): rc = main.add_job_summary() self.assertEqual(rc, 0) - with open(summary_path) as f: + with open(summary_path, encoding="utf-8") as f: content = f.read() self.assertIn(main.SUCCESS_TITLE, content) @@ -383,7 +384,7 @@ def test_failure_writes_failure_title(self): ): rc = main.add_job_summary() self.assertEqual(rc, 1) - with open(summary_path) as f: + with open(summary_path, encoding="utf-8") as f: content = f.read() self.assertIn(main.FAILURE_TITLE, content) self.assertIn("bad commit message", content) @@ -489,7 +490,8 @@ def setUp(self): self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) - open("result.txt", "w").close() + with open("result.txt", "w", encoding="utf-8"): + pass def tearDown(self): os.chdir(self._orig_dir)