Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 160 additions & 21 deletions src/kernel_ci_cloud_labs/providers/aws_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,51 @@


# providers/aws_provider.py
import os
import re
import time

from kernel_ci_cloud_labs.core.base_provider import BaseProvider
from kernel_ci_cloud_labs.core.logging_config import get_logger
from kernel_ci_cloud_labs.core.registry import register_provider

logger = get_logger(__name__)


# Kernel-side crash / stall patterns matched against each new CloudWatch
# event from the guest VM console. First hit triggers an early abort of
# wait_for_task_completion so the kernelci-api node is finished
# incomplete/Infrastructure rather than the loop spinning for an hour on a
# wedged guest.
_KERNEL_CRASH_PATTERNS = tuple(
re.compile(p) for p in (
# Fatal traps -- guest will normally exit, but if qemu wedges it won't.
r"Kernel panic - not syncing",
r"\bOops\s*:",
r"\bBUG\s*:",
r"general protection fault",
r"unable to handle kernel paging request",
r"double fault",
r"Internal error\s*:", # arm / arm64 die() banner
# Stalls / hangs -- kernel is still scheduling but wedged.
r"watchdog: BUG: soft lockup",
r"soft lockup - CPU#",
r"rcu_(?:sched|preempt|bh) detected stalls",
r"INFO: task .* blocked for more than",
)
)


def _scan_for_kernel_crash(events):
"""Return the first event whose message matches a crash pattern, else None."""
for event in events:
message = event.get("message") or ""
for pat in _KERNEL_CRASH_PATTERNS:
if pat.search(message):
return event
return None


@register_provider("aws")
class AWSProvider(BaseProvider):
"""AWS provider for running containers on Fargate."""
Expand Down Expand Up @@ -215,18 +253,68 @@ def wait_for_running(self, timeout=300):
logger.warning("✗ Task failed to reach RUNNING state: %s", e)
return False

def wait_for_task_completion(self):
def _build_vm_log_manager(self, start_time_ms):
"""Build an AWSCloudWatchManager scoped to this run's VM console group.

Returns None when no EC2 log group / run prefix is configured (e.g.
unit tests or a deployment without the cloudwatch section); in that
case wait_for_task_completion falls back to pure status polling
without crash detection.
"""
Wait for the ECS task to complete (reach STOPPED state).
cw_log_groups = self.config.get("cloudwatch", {}).get("log_groups", {}) or {}
ec2_log_group = next((k for k in cw_log_groups if "/ec2/" in k), None)
run_prefix = self.config.get("run_prefix")
if not ec2_log_group or not run_prefix:
logger.debug(
"VM crash detection disabled: ec2_log_group=%s run_prefix=%s",
ec2_log_group, run_prefix,
)
return None
try:
logs_client = self.auth.get_client("logs")
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning(
"Could not obtain CloudWatch logs client (%s) — "
"VM crash detection disabled", e,
)
return None
if not logs_client:
logger.warning("No CloudWatch logs client — VM crash detection disabled")
return None
from kernel_ci_cloud_labs.auth.aws_cloudwatch_manager import ( # noqa: PLC0415
AWSCloudWatchManager,
)
return AWSCloudWatchManager(
logs_client,
{},
run_prefix=run_prefix,
start_time_ms=start_time_ms,
ec2_log_group=ec2_log_group,
)

This uses boto3's tasks_stopped waiter which polls the task status
until it transitions to STOPPED, meaning:
- All containers have finished executing
- All VMs have been spawned, run tests, and uploaded results to S3
- The task has gracefully shut down
def wait_for_task_completion(self):
"""Wait for the ECS task to reach STOPPED, with crash / stall detection.

Polls task status until STOPPED. While waiting, tails the per-run VM
console log group ({EC2_LOG_GROUP}/{run_prefix} -- written by SSM Run
Command, see launch_vm.py) and aborts early on:

* Kernel-side crash patterns in the guest console (panic, Oops, BUG:,
soft lockup, RCU stall, hung task, GP fault, kernel paging fault).
The ECS task is stopped and a RuntimeError is raised so the caller
finishes the kernelci-api node incomplete/Infrastructure with the
matched line surfaced in error_msg.
* No new VM console output for PULLAB_TASK_HANG_THRESHOLD_SEC seconds
(default 600) -- silent stall, same treatment as a crash.
* Overall PULLAB_TASK_WAIT_TIMEOUT_SEC seconds elapsed (default 3600)
-- final safety net for whatever isn't covered above.

Crash detection requires both cloudwatch.log_groups (with an /ec2/
group) and run_prefix in the run config; otherwise the loop falls
back to plain status polling.

Returns:
dict: Final task status including exit codes
dict: Final task status including exit codes.
"""
if not self.task_arn:
logger.error("Cannot wait for completion - no task ARN available")
Expand All @@ -235,39 +323,90 @@ def wait_for_task_completion(self):
logger.info("Waiting for task to complete...")
logger.debug("Task ARN: %s", self.task_arn)

# Poll task status with periodic INFO logging so the user sees progress
import time as _time
poll_interval = float(os.getenv("PULLAB_TASK_POLL_INTERVAL_SEC") or 30)
log_interval = float(os.getenv("PULLAB_TASK_PROGRESS_LOG_SEC") or 120)
hang_threshold = float(os.getenv("PULLAB_TASK_HANG_THRESHOLD_SEC") or 600)
overall_timeout = float(os.getenv("PULLAB_TASK_WAIT_TIMEOUT_SEC") or 3600)

poll_interval = 30 # seconds between status checks
log_interval = 120 # seconds between INFO progress messages
start = _time.time()
start = time.time()
start_ms = int(start * 1000)
last_log_time = start
last_event_seen_at = start
# filter_log_events startTime is inclusive; -1 so the first poll
# picks up events with timestamp == start_ms.
last_event_ms = start_ms - 1

cw_manager = self._build_vm_log_manager(start_ms)
if cw_manager is None:
logger.info("Wait loop: VM crash detection disabled (no log group / run_prefix)")

while True:
elapsed = time.time() - start
if elapsed > overall_timeout:
logger.error(
"Overall wait timeout (%ds) exceeded — stopping task",
int(overall_timeout),
)
self.terminate_container()
raise RuntimeError(
f"task wait timeout exceeded after {int(elapsed)}s"
)

status = self.get_task_status()
if not status:
logger.warning("Could not retrieve task status, retrying...")
_time.sleep(poll_interval)
time.sleep(poll_interval)
continue

task_status = status.get("status", "UNKNOWN")
elapsed = int(_time.time() - start)

if task_status == "STOPPED":
logger.info("✓ Task completed (elapsed: %dm %ds)", elapsed // 60, elapsed % 60)
logger.info(
"✓ Task completed (elapsed: %dm %ds)",
int(elapsed) // 60, int(elapsed) % 60,
)
break

now = _time.time()
# Tail the VM console group for crash patterns / progress.
if cw_manager is not None:
new_events = cw_manager.get_logs_with_filter(
start_time=last_event_ms + 1
) or []
if new_events:
last_event_seen_at = time.time()
for ev in new_events:
ts = ev.get("timestamp", 0)
if ts > last_event_ms:
last_event_ms = ts
hit = _scan_for_kernel_crash(new_events)
if hit:
msg = (hit.get("message") or "").strip()[:300]
logger.error(
"Kernel crash/stall in VM console (stream=%s): %s",
hit.get("logStreamName", "?"), msg,
)
self.terminate_container()
raise RuntimeError(f"kernel crash detected in VM: {msg}")
elif (time.time() - last_event_seen_at) > hang_threshold:
logger.error(
"No VM console output for %ds (hang threshold %ds) — stopping task",
int(time.time() - last_event_seen_at),
int(hang_threshold),
)
self.terminate_container()
raise RuntimeError(
f"no VM console output for {int(hang_threshold)}s"
)

now = time.time()
if now - last_log_time >= log_interval:
last_log_time = now
logger.info(
"Task still running... (status: %s, elapsed: %dm %ds)",
task_status,
elapsed // 60,
elapsed % 60,
task_status, int(elapsed) // 60, int(elapsed) % 60,
)

_time.sleep(poll_interval)
time.sleep(poll_interval)

# Get final status to check exit codes
final_status = self.get_task_status()
Expand Down
Loading
Loading