From b4050b6c92f3f87fb1b40c69d040969e914343ac Mon Sep 17 00:00:00 2001 From: Denys Fedoryshchenko Date: Sat, 23 May 2026 00:48:26 +0300 Subject: [PATCH] timeout: wait_for_task_completion is no longer an unbounded while True Some VM might be incredibly expensive, so leaving them in running state infinitely can drain whole project budget. So we need to take several measures: 1)Limit running time 2)Try to detect stalled boot (keywords, no log updates, etc) Signed-off-by: Denys Fedoryshchenko --- .../providers/aws_provider.py | 181 ++++++++++++++++-- tests/test_provider_lifecycle.py | 157 ++++++++++++++- 2 files changed, 315 insertions(+), 23 deletions(-) diff --git a/src/kernel_ci_cloud_labs/providers/aws_provider.py b/src/kernel_ci_cloud_labs/providers/aws_provider.py index 472f070..0c72c02 100644 --- a/src/kernel_ci_cloud_labs/providers/aws_provider.py +++ b/src/kernel_ci_cloud_labs/providers/aws_provider.py @@ -6,6 +6,10 @@ # providers/aws_provider.py +import os +import re +import time + from kernel_ci_cloud_labs.core.base_provider import BaseProvider from kernel_ci_cloud_labs.core.logging_config import get_logger from kernel_ci_cloud_labs.core.registry import register_provider @@ -13,6 +17,40 @@ logger = get_logger(__name__) +# Kernel-side crash / stall patterns matched against each new CloudWatch +# event from the guest VM console. First hit triggers an early abort of +# wait_for_task_completion so the kernelci-api node is finished +# incomplete/Infrastructure rather than the loop spinning for an hour on a +# wedged guest. +_KERNEL_CRASH_PATTERNS = tuple( + re.compile(p) for p in ( + # Fatal traps -- guest will normally exit, but if qemu wedges it won't. + r"Kernel panic - not syncing", + r"\bOops\s*:", + r"\bBUG\s*:", + r"general protection fault", + r"unable to handle kernel paging request", + r"double fault", + r"Internal error\s*:", # arm / arm64 die() banner + # Stalls / hangs -- kernel is still scheduling but wedged. + r"watchdog: BUG: soft lockup", + r"soft lockup - CPU#", + r"rcu_(?:sched|preempt|bh) detected stalls", + r"INFO: task .* blocked for more than", + ) +) + + +def _scan_for_kernel_crash(events): + """Return the first event whose message matches a crash pattern, else None.""" + for event in events: + message = event.get("message") or "" + for pat in _KERNEL_CRASH_PATTERNS: + if pat.search(message): + return event + return None + + @register_provider("aws") class AWSProvider(BaseProvider): """AWS provider for running containers on Fargate.""" @@ -215,18 +253,68 @@ def wait_for_running(self, timeout=300): logger.warning("✗ Task failed to reach RUNNING state: %s", e) return False - def wait_for_task_completion(self): + def _build_vm_log_manager(self, start_time_ms): + """Build an AWSCloudWatchManager scoped to this run's VM console group. + + Returns None when no EC2 log group / run prefix is configured (e.g. + unit tests or a deployment without the cloudwatch section); in that + case wait_for_task_completion falls back to pure status polling + without crash detection. """ - Wait for the ECS task to complete (reach STOPPED state). + cw_log_groups = self.config.get("cloudwatch", {}).get("log_groups", {}) or {} + ec2_log_group = next((k for k in cw_log_groups if "/ec2/" in k), None) + run_prefix = self.config.get("run_prefix") + if not ec2_log_group or not run_prefix: + logger.debug( + "VM crash detection disabled: ec2_log_group=%s run_prefix=%s", + ec2_log_group, run_prefix, + ) + return None + try: + logs_client = self.auth.get_client("logs") + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Could not obtain CloudWatch logs client (%s) — " + "VM crash detection disabled", e, + ) + return None + if not logs_client: + logger.warning("No CloudWatch logs client — VM crash detection disabled") + return None + from kernel_ci_cloud_labs.auth.aws_cloudwatch_manager import ( # noqa: PLC0415 + AWSCloudWatchManager, + ) + return AWSCloudWatchManager( + logs_client, + {}, + run_prefix=run_prefix, + start_time_ms=start_time_ms, + ec2_log_group=ec2_log_group, + ) - This uses boto3's tasks_stopped waiter which polls the task status - until it transitions to STOPPED, meaning: - - All containers have finished executing - - All VMs have been spawned, run tests, and uploaded results to S3 - - The task has gracefully shut down + def wait_for_task_completion(self): + """Wait for the ECS task to reach STOPPED, with crash / stall detection. + + Polls task status until STOPPED. While waiting, tails the per-run VM + console log group ({EC2_LOG_GROUP}/{run_prefix} -- written by SSM Run + Command, see launch_vm.py) and aborts early on: + + * Kernel-side crash patterns in the guest console (panic, Oops, BUG:, + soft lockup, RCU stall, hung task, GP fault, kernel paging fault). + The ECS task is stopped and a RuntimeError is raised so the caller + finishes the kernelci-api node incomplete/Infrastructure with the + matched line surfaced in error_msg. + * No new VM console output for PULLAB_TASK_HANG_THRESHOLD_SEC seconds + (default 600) -- silent stall, same treatment as a crash. + * Overall PULLAB_TASK_WAIT_TIMEOUT_SEC seconds elapsed (default 3600) + -- final safety net for whatever isn't covered above. + + Crash detection requires both cloudwatch.log_groups (with an /ec2/ + group) and run_prefix in the run config; otherwise the loop falls + back to plain status polling. Returns: - dict: Final task status including exit codes + dict: Final task status including exit codes. """ if not self.task_arn: logger.error("Cannot wait for completion - no task ARN available") @@ -235,39 +323,90 @@ def wait_for_task_completion(self): logger.info("Waiting for task to complete...") logger.debug("Task ARN: %s", self.task_arn) - # Poll task status with periodic INFO logging so the user sees progress - import time as _time + poll_interval = float(os.getenv("PULLAB_TASK_POLL_INTERVAL_SEC") or 30) + log_interval = float(os.getenv("PULLAB_TASK_PROGRESS_LOG_SEC") or 120) + hang_threshold = float(os.getenv("PULLAB_TASK_HANG_THRESHOLD_SEC") or 600) + overall_timeout = float(os.getenv("PULLAB_TASK_WAIT_TIMEOUT_SEC") or 3600) - poll_interval = 30 # seconds between status checks - log_interval = 120 # seconds between INFO progress messages - start = _time.time() + start = time.time() + start_ms = int(start * 1000) last_log_time = start + last_event_seen_at = start + # filter_log_events startTime is inclusive; -1 so the first poll + # picks up events with timestamp == start_ms. + last_event_ms = start_ms - 1 + + cw_manager = self._build_vm_log_manager(start_ms) + if cw_manager is None: + logger.info("Wait loop: VM crash detection disabled (no log group / run_prefix)") while True: + elapsed = time.time() - start + if elapsed > overall_timeout: + logger.error( + "Overall wait timeout (%ds) exceeded — stopping task", + int(overall_timeout), + ) + self.terminate_container() + raise RuntimeError( + f"task wait timeout exceeded after {int(elapsed)}s" + ) + status = self.get_task_status() if not status: logger.warning("Could not retrieve task status, retrying...") - _time.sleep(poll_interval) + time.sleep(poll_interval) continue task_status = status.get("status", "UNKNOWN") - elapsed = int(_time.time() - start) if task_status == "STOPPED": - logger.info("✓ Task completed (elapsed: %dm %ds)", elapsed // 60, elapsed % 60) + logger.info( + "✓ Task completed (elapsed: %dm %ds)", + int(elapsed) // 60, int(elapsed) % 60, + ) break - now = _time.time() + # Tail the VM console group for crash patterns / progress. + if cw_manager is not None: + new_events = cw_manager.get_logs_with_filter( + start_time=last_event_ms + 1 + ) or [] + if new_events: + last_event_seen_at = time.time() + for ev in new_events: + ts = ev.get("timestamp", 0) + if ts > last_event_ms: + last_event_ms = ts + hit = _scan_for_kernel_crash(new_events) + if hit: + msg = (hit.get("message") or "").strip()[:300] + logger.error( + "Kernel crash/stall in VM console (stream=%s): %s", + hit.get("logStreamName", "?"), msg, + ) + self.terminate_container() + raise RuntimeError(f"kernel crash detected in VM: {msg}") + elif (time.time() - last_event_seen_at) > hang_threshold: + logger.error( + "No VM console output for %ds (hang threshold %ds) — stopping task", + int(time.time() - last_event_seen_at), + int(hang_threshold), + ) + self.terminate_container() + raise RuntimeError( + f"no VM console output for {int(hang_threshold)}s" + ) + + now = time.time() if now - last_log_time >= log_interval: last_log_time = now logger.info( "Task still running... (status: %s, elapsed: %dm %ds)", - task_status, - elapsed // 60, - elapsed % 60, + task_status, int(elapsed) // 60, int(elapsed) % 60, ) - _time.sleep(poll_interval) + time.sleep(poll_interval) # Get final status to check exit codes final_status = self.get_task_status() diff --git a/tests/test_provider_lifecycle.py b/tests/test_provider_lifecycle.py index d914cbc..26276aa 100644 --- a/tests/test_provider_lifecycle.py +++ b/tests/test_provider_lifecycle.py @@ -5,11 +5,15 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import Mock +from types import SimpleNamespace +from unittest.mock import Mock, patch import pytest -from kernel_ci_cloud_labs.providers.aws_provider import AWSProvider +from kernel_ci_cloud_labs.providers.aws_provider import ( + AWSProvider, + _scan_for_kernel_crash, +) class TestAWSProviderLifecycle: @@ -83,3 +87,152 @@ def test_stop_all_tasks_stops_running_tasks(self): provider.stop_all_tasks() assert mock_ecs.stop_task.call_count == 2 + + +class TestAWSProviderWaitLoop: + """wait_for_task_completion: crash / stall / overall-timeout detection.""" + + @staticmethod + def _make_provider(monkeypatch, env=None): + # Deterministic clock: time advances only when sleep is called. This + # lets the wait loop spin synchronously while elapsed time advances + # in fixed steps -- a hang threshold of 5s reliably fires in ~6 polls. + clock = {"t": 1000.0} + + def fake_time(): + return clock["t"] + + def fake_sleep(_): + clock["t"] += 1.0 + + fake_time_mod = SimpleNamespace(time=fake_time, sleep=fake_sleep) + monkeypatch.setattr( + "kernel_ci_cloud_labs.providers.aws_provider.time", fake_time_mod, + ) + for k, v in (env or {}).items(): + monkeypatch.setenv(k, v) + + mock_auth = Mock() + mock_ecs = Mock() + mock_auth.get_client.return_value = mock_ecs + config = { + "ecs": {"cluster_name": "c", "task_definition": {"family": "t"}}, + } + provider = AWSProvider(mock_auth, config) + provider.authenticate() + provider.task_arn = "arn:aws:ecs:::task/c/abc" + return provider, mock_ecs + + def test_clean_run_returns_when_stopped(self, monkeypatch): + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_PROGRESS_LOG_SEC": "9999", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + # Two RUNNING then STOPPED inside the loop + one STOPPED for the + # post-loop final_status call. + statuses = iter([ + {"status": "RUNNING", "containers": []}, + {"status": "RUNNING", "containers": []}, + {"status": "STOPPED", "containers": []}, + {"status": "STOPPED", "containers": []}, + ]) + with patch.object(p, "get_task_status", side_effect=lambda: next(statuses)), \ + patch.object(p, "_build_vm_log_manager", return_value=None): + result = p.wait_for_task_completion() + assert result["status"] == "STOPPED" + mock_ecs.stop_task.assert_not_called() + + def test_kernel_crash_terminates_and_raises(self, monkeypatch): + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + cw_manager = Mock() + cw_manager.get_logs_with_filter.return_value = [ + { + "timestamp": 1001000, + "message": "Kernel panic - not syncing: VFS: Unable to mount root fs", + "logStreamName": "cmd-id/i-12345/stdout", + }, + ] + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=cw_manager): + with pytest.raises(RuntimeError, match="kernel crash detected"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + def test_hang_threshold_terminates_and_raises(self, monkeypatch): + # fake_sleep advances 1s/iter; 5s threshold fires after ~6 polls. + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "5", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + cw_manager = Mock() + cw_manager.get_logs_with_filter.return_value = [] + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=cw_manager): + with pytest.raises(RuntimeError, match="no VM console output"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + def test_overall_timeout_terminates_and_raises(self, monkeypatch): + # No log manager so the only abort path is the overall-timeout cap. + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "3", + }) + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=None): + with pytest.raises(RuntimeError, match="task wait timeout exceeded"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + +class TestScanForKernelCrash: + """Pure-helper matcher: kernel-side crash / stall patterns.""" + + @pytest.mark.parametrize("message", [ + "Kernel panic - not syncing: VFS: Unable to mount root fs", + "Oops: 0000 [#1] SMP PTI", + "BUG: kernel NULL pointer dereference, address: 0000000000000000", + "watchdog: BUG: soft lockup - CPU#0 stuck for 22s!", + "soft lockup - CPU#3 stuck", + "INFO: rcu_sched detected stalls on CPUs/tasks:", + "INFO: task kworker/0:1:42 blocked for more than 120 seconds.", + "general protection fault: 0000 [#1] PREEMPT SMP", + "unable to handle kernel paging request at ffff800010000000", + "Internal error: Oops: 96000005 [#1] SMP", + ]) + def test_matches_known_crash_patterns(self, message): + hit = _scan_for_kernel_crash([{"message": message}]) + assert hit is not None + assert hit["message"] == message + + def test_no_match_returns_none(self): + # Bare "Call Trace:" is too noisy to be a crash on its own -- we + # deliberately don't match it; verify it doesn't trip the matcher. + assert _scan_for_kernel_crash([ + {"message": "Booting Linux..."}, + {"message": "systemd: started kernel-ci-runner"}, + {"message": "Call Trace:"}, + ]) is None + + def test_returns_first_hit(self): + events = [ + {"message": "Booting"}, + {"message": "Kernel panic - not syncing"}, + {"message": "Oops:"}, + ] + hit = _scan_for_kernel_crash(events) + assert hit["message"] == "Kernel panic - not syncing" + + def test_handles_missing_or_none_message(self): + # Real CloudWatch events sometimes carry no message field. + assert _scan_for_kernel_crash( + [{}, {"message": None}, {"message": ""}] + ) is None