diff --git a/src/kernel_ci_cloud_labs/providers/aws_provider.py b/src/kernel_ci_cloud_labs/providers/aws_provider.py index 472f070..0c72c02 100644 --- a/src/kernel_ci_cloud_labs/providers/aws_provider.py +++ b/src/kernel_ci_cloud_labs/providers/aws_provider.py @@ -6,6 +6,10 @@ # providers/aws_provider.py +import os +import re +import time + from kernel_ci_cloud_labs.core.base_provider import BaseProvider from kernel_ci_cloud_labs.core.logging_config import get_logger from kernel_ci_cloud_labs.core.registry import register_provider @@ -13,6 +17,40 @@ logger = get_logger(__name__) +# Kernel-side crash / stall patterns matched against each new CloudWatch +# event from the guest VM console. First hit triggers an early abort of +# wait_for_task_completion so the kernelci-api node is finished +# incomplete/Infrastructure rather than the loop spinning for an hour on a +# wedged guest. +_KERNEL_CRASH_PATTERNS = tuple( + re.compile(p) for p in ( + # Fatal traps -- guest will normally exit, but if qemu wedges it won't. + r"Kernel panic - not syncing", + r"\bOops\s*:", + r"\bBUG\s*:", + r"general protection fault", + r"unable to handle kernel paging request", + r"double fault", + r"Internal error\s*:", # arm / arm64 die() banner + # Stalls / hangs -- kernel is still scheduling but wedged. 
+ r"watchdog: BUG: soft lockup", + r"soft lockup - CPU#", + r"rcu_(?:sched|preempt|bh) (?:self-)?detected (?:expedited )?stalls?", + r"INFO: task .* blocked for more than", + ) +) + + +def _scan_for_kernel_crash(events): + """Return the first event whose message matches a crash pattern, else None.""" + for event in events: + message = event.get("message") or "" + for pat in _KERNEL_CRASH_PATTERNS: + if pat.search(message): + return event + return None + + @register_provider("aws") class AWSProvider(BaseProvider): """AWS provider for running containers on Fargate.""" @@ -215,18 +253,68 @@ def wait_for_running(self, timeout=300): logger.warning("✗ Task failed to reach RUNNING state: %s", e) return False - def wait_for_task_completion(self): + def _build_vm_log_manager(self, start_time_ms): + """Build an AWSCloudWatchManager scoped to this run's VM console group. + + Returns None when no EC2 log group / run prefix is configured (e.g. + unit tests or a deployment without the cloudwatch section) or when + no CloudWatch logs client can be obtained; wait_for_task_completion + then falls back to pure status polling without crash detection. """ - Wait for the ECS task to complete (reach STOPPED state). 
+ cw_log_groups = (self.config.get("cloudwatch") or {}).get("log_groups") or {} + ec2_log_group = next((k for k in cw_log_groups if "/ec2/" in k), None) + run_prefix = self.config.get("run_prefix") + if not ec2_log_group or not run_prefix: + logger.debug( + "VM crash detection disabled: ec2_log_group=%s run_prefix=%s", + ec2_log_group, run_prefix, + ) + return None + try: + logs_client = self.auth.get_client("logs") + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Could not obtain CloudWatch logs client (%s) — " + "VM crash detection disabled", e, + ) + return None + if not logs_client: + logger.warning("No CloudWatch logs client — VM crash detection disabled") + return None + from kernel_ci_cloud_labs.auth.aws_cloudwatch_manager import ( # noqa: PLC0415 + AWSCloudWatchManager, + ) + return AWSCloudWatchManager( + logs_client, + {}, + run_prefix=run_prefix, + start_time_ms=start_time_ms, + ec2_log_group=ec2_log_group, + ) - This uses boto3's tasks_stopped waiter which polls the task status - until it transitions to STOPPED, meaning: - - All containers have finished executing - - All VMs have been spawned, run tests, and uploaded results to S3 - - The task has gracefully shut down + def wait_for_task_completion(self): + """Wait for the ECS task to reach STOPPED, with crash / stall detection. + + Polls task status until STOPPED. While waiting, tails the per-run VM + console log group ({EC2_LOG_GROUP}/{run_prefix} -- written by SSM Run + Command, see launch_vm.py) and aborts early on: + + * Kernel-side crash patterns in the guest console (panic, Oops, BUG:, double fault, + soft lockup, RCU stall, hung task, GP fault, kernel paging fault, Internal error). + The ECS task is stopped and a RuntimeError is raised so the caller + finishes the kernelci-api node incomplete/Infrastructure with the + matched line surfaced in error_msg. + * No new VM console output for PULLAB_TASK_HANG_THRESHOLD_SEC seconds + (default 600) -- silent stall, same treatment as a crash. 
+ * Overall PULLAB_TASK_WAIT_TIMEOUT_SEC seconds elapsed (default 3600) + -- final safety net for whatever isn't covered above. + + Crash detection requires both cloudwatch.log_groups (with an /ec2/ + group) and run_prefix in the run config; otherwise the loop falls + back to plain status polling. Returns: - dict: Final task status including exit codes + dict: Final task status including exit codes. """ if not self.task_arn: logger.error("Cannot wait for completion - no task ARN available") @@ -235,39 +323,90 @@ def wait_for_task_completion(self): logger.info("Waiting for task to complete...") logger.debug("Task ARN: %s", self.task_arn) - # Poll task status with periodic INFO logging so the user sees progress - import time as _time + poll_interval = float(os.getenv("PULLAB_TASK_POLL_INTERVAL_SEC") or 30) + log_interval = float(os.getenv("PULLAB_TASK_PROGRESS_LOG_SEC") or 120) + hang_threshold = float(os.getenv("PULLAB_TASK_HANG_THRESHOLD_SEC") or 600) + overall_timeout = float(os.getenv("PULLAB_TASK_WAIT_TIMEOUT_SEC") or 3600) - poll_interval = 30 # seconds between status checks - log_interval = 120 # seconds between INFO progress messages - start = _time.time() + start = time.time() + start_ms = int(start * 1000) last_log_time = start + last_event_seen_at = start + # filter_log_events startTime is inclusive; -1 so the first poll + # picks up events with timestamp == start_ms. 
+ last_event_ms = start_ms - 1 + + cw_manager = self._build_vm_log_manager(start_ms) + if cw_manager is None: + logger.info("Wait loop: VM crash detection disabled (no log group / run_prefix)") while True: + elapsed = time.time() - start + if elapsed > overall_timeout: + logger.error( + "Overall wait timeout (%ds) exceeded — stopping task", + int(overall_timeout), + ) + self.terminate_container() + raise RuntimeError( + f"task wait timeout exceeded after {int(elapsed)}s" + ) + status = self.get_task_status() if not status: logger.warning("Could not retrieve task status, retrying...") - _time.sleep(poll_interval) + time.sleep(poll_interval) continue task_status = status.get("status", "UNKNOWN") - elapsed = int(_time.time() - start) if task_status == "STOPPED": - logger.info("✓ Task completed (elapsed: %dm %ds)", elapsed // 60, elapsed % 60) + logger.info( + "✓ Task completed (elapsed: %dm %ds)", + int(elapsed) // 60, int(elapsed) % 60, + ) break - now = _time.time() + # Tail the VM console group for crash patterns / progress. 
+ if cw_manager is not None: + new_events = cw_manager.get_logs_with_filter( + start_time=last_event_ms + 1 + ) or [] + if new_events: + last_event_seen_at = time.time() + for ev in new_events: + ts = ev.get("timestamp", 0) + if ts > last_event_ms: + last_event_ms = ts + hit = _scan_for_kernel_crash(new_events) + if hit: + msg = (hit.get("message") or "").strip()[:300] + logger.error( + "Kernel crash/stall in VM console (stream=%s): %s", + hit.get("logStreamName", "?"), msg, + ) + self.terminate_container() + raise RuntimeError(f"kernel crash detected in VM: {msg}") + elif (time.time() - last_event_seen_at) > hang_threshold: + logger.error( + "No VM console output for %ds (hang threshold %ds) — stopping task", + int(time.time() - last_event_seen_at), + int(hang_threshold), + ) + self.terminate_container() + raise RuntimeError( + f"no VM console output for {int(hang_threshold)}s" + ) + + now = time.time() if now - last_log_time >= log_interval: last_log_time = now logger.info( "Task still running... 
(status: %s, elapsed: %dm %ds)", - task_status, - elapsed // 60, - elapsed % 60, + task_status, int(elapsed) // 60, int(elapsed) % 60, ) - _time.sleep(poll_interval) + time.sleep(poll_interval) # Get final status to check exit codes final_status = self.get_task_status() diff --git a/tests/test_provider_lifecycle.py b/tests/test_provider_lifecycle.py index d914cbc..26276aa 100644 --- a/tests/test_provider_lifecycle.py +++ b/tests/test_provider_lifecycle.py @@ -5,11 +5,15 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import Mock +from types import SimpleNamespace +from unittest.mock import Mock, patch import pytest -from kernel_ci_cloud_labs.providers.aws_provider import AWSProvider +from kernel_ci_cloud_labs.providers.aws_provider import ( + AWSProvider, + _scan_for_kernel_crash, +) class TestAWSProviderLifecycle: @@ -83,3 +87,153 @@ def test_stop_all_tasks_stops_running_tasks(self): provider.stop_all_tasks() assert mock_ecs.stop_task.call_count == 2 + + +class TestAWSProviderWaitLoop: + """wait_for_task_completion: crash / stall / overall-timeout detection.""" + + @staticmethod + def _make_provider(monkeypatch, env=None): + # Deterministic clock: time advances only when sleep is called (by 1s, + # ignoring the requested interval), so the wait loop spins synchronously + # in fixed steps -- a hang threshold of 5s reliably fires in ~6 polls. 
+ clock = {"t": 1000.0} + + def fake_time(): + return clock["t"] + + def fake_sleep(_): + clock["t"] += 1.0 + + fake_time_mod = SimpleNamespace(time=fake_time, sleep=fake_sleep) + monkeypatch.setattr( + "kernel_ci_cloud_labs.providers.aws_provider.time", fake_time_mod, + ) + for k, v in (env or {}).items(): + monkeypatch.setenv(k, v) + + mock_auth = Mock() + mock_ecs = Mock() + mock_auth.get_client.return_value = mock_ecs + config = { + "ecs": {"cluster_name": "c", "task_definition": {"family": "t"}}, + } + provider = AWSProvider(mock_auth, config) + provider.authenticate() + provider.task_arn = "arn:aws:ecs:::task/c/abc" + return provider, mock_ecs + + def test_clean_run_returns_when_stopped(self, monkeypatch): + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_PROGRESS_LOG_SEC": "9999", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + # Two RUNNING then STOPPED inside the loop + one STOPPED for the + # post-loop final_status call. 
+ statuses = iter([ + {"status": "RUNNING", "containers": []}, + {"status": "RUNNING", "containers": []}, + {"status": "STOPPED", "containers": []}, + {"status": "STOPPED", "containers": []}, + ]) + with patch.object(p, "get_task_status", side_effect=lambda: next(statuses)), \ + patch.object(p, "_build_vm_log_manager", return_value=None): + result = p.wait_for_task_completion() + assert result["status"] == "STOPPED" + mock_ecs.stop_task.assert_not_called() + + def test_kernel_crash_terminates_and_raises(self, monkeypatch): + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + cw_manager = Mock() + cw_manager.get_logs_with_filter.return_value = [ + { + "timestamp": 1001000, + "message": "Kernel panic - not syncing: VFS: Unable to mount root fs", + "logStreamName": "cmd-id/i-12345/stdout", + }, + ] + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=cw_manager): + with pytest.raises(RuntimeError, match="kernel crash detected"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + def test_hang_threshold_terminates_and_raises(self, monkeypatch): + # fake_sleep advances 1s/iter; 5s threshold fires after ~6 polls. 
+ p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "5", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "9999", + }) + cw_manager = Mock() + cw_manager.get_logs_with_filter.return_value = [] + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=cw_manager): + with pytest.raises(RuntimeError, match="no VM console output"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + def test_overall_timeout_terminates_and_raises(self, monkeypatch): + # No log manager so the only abort path is the overall-timeout cap. + p, mock_ecs = self._make_provider(monkeypatch, env={ + "PULLAB_TASK_POLL_INTERVAL_SEC": "0", + "PULLAB_TASK_HANG_THRESHOLD_SEC": "9999", + "PULLAB_TASK_WAIT_TIMEOUT_SEC": "3", + }) + with patch.object(p, "get_task_status", return_value={"status": "RUNNING", "containers": []}), \ + patch.object(p, "_build_vm_log_manager", return_value=None): + with pytest.raises(RuntimeError, match="task wait timeout exceeded"): + p.wait_for_task_completion() + mock_ecs.stop_task.assert_called_once() + + +class TestScanForKernelCrash: + """Pure-helper matcher: kernel-side crash / stall patterns.""" + + @pytest.mark.parametrize("message", [ + "Kernel panic - not syncing: VFS: Unable to mount root fs", + "Oops: 0000 [#1] SMP PTI", + "BUG: kernel NULL pointer dereference, address: 0000000000000000", + "watchdog: BUG: soft lockup - CPU#0 stuck for 22s!", + "soft lockup - CPU#3 stuck", + "INFO: rcu_sched detected stalls on CPUs/tasks:", + "INFO: rcu_preempt self-detected stall on CPU", + "INFO: task kworker/0:1:42 blocked for more than 120 seconds.", + "general protection fault: 0000 [#1] PREEMPT SMP", + "unable to handle kernel paging request at ffff800010000000", + "Internal error: Oops: 96000005 [#1] SMP", + ]) + def test_matches_known_crash_patterns(self, message): + hit = _scan_for_kernel_crash([{"message": message}]) + assert hit is not 
None + assert hit["message"] == message + + def test_no_match_returns_none(self): + # Bare "Call Trace:" is too noisy to be a crash on its own -- we + # deliberately don't match it; verify it doesn't trip the matcher. + assert _scan_for_kernel_crash([ + {"message": "Booting Linux..."}, + {"message": "systemd: started kernel-ci-runner"}, + {"message": "Call Trace:"}, + ]) is None + + def test_returns_first_hit(self): + events = [ + {"message": "Booting"}, + {"message": "Kernel panic - not syncing"}, + {"message": "Oops:"}, + ] + hit = _scan_for_kernel_crash(events) + assert hit["message"] == "Kernel panic - not syncing" + + def test_handles_missing_or_none_message(self): + # Real CloudWatch events sometimes carry no message field. + assert _scan_for_kernel_crash( + [{}, {"message": None}, {"message": ""}] + ) is None