Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pathwaysutils/elastic/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def validate(self) -> Set[int]:
_logger.debug(
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
)
if not is_error_due_to_slice_down(error):
if not is_error_due_to_slice_down(error, log_traceback=False):
raise
return active_slice_indices

Expand Down Expand Up @@ -249,7 +249,9 @@ def wait_for_slices(
time.sleep(time_to_sleep)


def is_error_due_to_slice_down(error: Exception) -> bool:
def is_error_due_to_slice_down(
error: Exception, log_traceback: bool = True
) -> bool:
"""Returns True if the error is due to slice down.

The error types that are considered due to slice down are
Expand All @@ -261,6 +263,7 @@ def is_error_due_to_slice_down(error: Exception) -> bool:

Args:
error: The error to check.
log_traceback: If True, log the traceback of the error.
"""
error_due_to_slice_down = False
traceback_logging_level = logging.DEBUG
Expand Down Expand Up @@ -293,6 +296,7 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
if not error_due_to_slice_down:
_logger.debug("Caught an error not due to slice down")

_logger.log(traceback_logging_level, "Error details:", exc_info=True)
if log_traceback:
_logger.log(traceback_logging_level, "Error details:", exc_info=True)

return error_due_to_slice_down
227 changes: 138 additions & 89 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ class Manager:
all_slice_indices: Set[int]
active_slice_indices: Set[int]
new_slice_event: threading.Event

available_inactive_slices: Set[int]
_stop_event: threading.Event | None
_monitor_thread: threading.Thread | None
def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
"""Initializes the manager.

Expand All @@ -113,6 +115,45 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
slice_to_devices=self.slice_to_devices
)
self.new_slice_event = threading.Event()
self.available_inactive_slices = set()

self._stop_event = None
self._monitor_thread = None

def start_monitoring(self, poll_interval: float | int = 10) -> None:
"""Starts the background monitor thread.

Args:
poll_interval: The number of seconds to wait between activity checks.
"""
if self._monitor_thread is not None and self._monitor_thread.is_alive():
_logger.warning("Monitor thread is already running.")
return

self._stop_event = threading.Event()
self._monitor_thread = threading.Thread(
target=self._monitor_new_slices,
args=(self._stop_event, poll_interval),
daemon=True,
)
self._monitor_thread.start()
_logger.info("Elastic monitor thread started with interval %s.", poll_interval)

def close(self) -> None:
"""Stops the background monitor thread."""
if self._stop_event is not None:
self._stop_event.set()
if self._monitor_thread is not None:
_logger.info("Closing manager, waiting for monitor thread to stop...")
try:
self._monitor_thread.join(timeout=5)
except RuntimeError as e:
if "cannot join thread" in str(e):
pass
else:
raise
self._monitor_thread = None
self._stop_event = None

@functools.cached_property
def total_slice_count(self) -> int:
Expand Down Expand Up @@ -171,37 +212,58 @@ def _cleanup_on_retry(self):
for array in jax.live_arrays():
array.delete()

def _monitor_new_slices(
self, stop_event: threading.Event, poll_interval: float | int
) -> None:
"""Monitors for new slices and sets the `new_slice_event` if found."""
while not stop_event.wait(poll_interval):
try:
if not self.inactive_slice_indices:
_logger.debug("No inactive slices to check.")
continue
def _check_inactive_slices(self) -> None:
"""Checks inactive slices and updates available_inactive_slices."""
if not self.inactive_slice_indices:
_logger.debug("No inactive slices to check.")
if self.available_inactive_slices:
self.available_inactive_slices.clear()
self.new_slice_event.clear()
return

_logger.debug(
"Now checking inactive slices %s", self.inactive_slice_indices
)
inactive_slice_to_devices = {
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
}
found_slices = elastic.get_active_slice_indices(
inactive_slice_to_devices
)

_logger.debug(
"Checking inactive slices: %s", self.inactive_slice_indices
)
inactive_slice_to_devices = {
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
}
newly_active_indices = elastic.get_active_slice_indices(
inactive_slice_to_devices
)
_logger.debug(
"Found available and inactive slices %s", found_slices
)

if newly_active_indices:
_logger.info(
"New slices found: %s. Setting new slice event.",
newly_active_indices,
)
self.new_slice_event.set()
return
if found_slices != self.available_inactive_slices:
_logger.info(
"Newly available but inactive slices %s", found_slices
)
self.available_inactive_slices = found_slices
if self.available_inactive_slices:
self.new_slice_event.set()
else:
self.new_slice_event.clear()

_logger.debug("No new slices found.")
except Exception: # pylint: disable=broad-exception-caught
_logger.exception("Error in monitor thread")
def _monitor_new_slices(
self, stop_event: threading.Event, poll_interval: float | int
) -> None:
"""Monitors for new slices and updates available_inactive_slices."""
_logger.info("Elastic monitor thread started.")
try:
while not stop_event.wait(poll_interval):
try:
self._check_inactive_slices()
except Exception: # pylint: disable=broad-exception-caught
_logger.exception("Error in monitor thread loop")
except BaseException as e:
_logger.critical(
"Catastrophic error in monitor thread, thread is dying!",
exc_info=True,
)
raise
finally:
_logger.info("Elastic monitor thread stopped.")

def elastic_retry(
self,
Expand Down Expand Up @@ -286,6 +348,7 @@ def elastic_retry(
def decorator(func: _F) -> _F:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
self.start_monitoring(poll_interval)

def attempt_execution(attempt: int) -> Any:
_logger.info("Elastic attempt %d", attempt)
Expand All @@ -295,73 +358,59 @@ def attempt_execution(attempt: int) -> Any:
poll_interval=poll_interval,
timeout=timeout,
)
self.available_inactive_slices.clear()
if pre_callback is not None:
pre_callback()

with jax.default_device(self.default_device):
self.new_slice_event.clear()
stop_event = threading.Event()

if target_slice_count < self.total_slice_count:
monitor_thread = threading.Thread(
target=self._monitor_new_slices,
args=(stop_event, poll_interval),
daemon=True,
)
monitor_thread.start()
else:
monitor_thread = None
return func(*args, **kwargs)

def handle_scale_up_error(attempt: int, error: ScaleUpSignalError) -> None:
_logger.info("Scale up requested.")
_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()

if not retry_policy(attempt, error):
_logger.info("Retry policy rejected retry after ScaleUpSignalError.")
raise ElasticRuntimeError(f"Elastic attempt {attempt} failed.") from error

_logger.info("Retrying.")

def handle_slice_down_error(attempt: int, error: jax.errors.JaxRuntimeError) -> None:
if not elastic.is_error_due_to_slice_down(error):
raise

if self.new_slice_event.is_set():
_logger.info("Slice down event and new slice available detected.")
else:
_logger.info("Slice down event detected.")

_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()

if not retry_policy(attempt, error):
_logger.info("Retry policy rejected retry after JaxRuntimeError.")
raise ElasticRuntimeError(f"Elastic attempt {attempt} failed.") from error

_logger.info("Retrying.")

try:
attempt = 1
while True:
try:
return func(*args, **kwargs)
finally:
stop_event.set()
if monitor_thread is not None:
monitor_thread.join()

attempt = 1
while True:
try:
return attempt_execution(attempt)
except ScaleUpSignalError as error:
_logger.info("Scale up requested.")
_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()

if not retry_policy(attempt, error):
_logger.info(
"Retry policy rejected retry after ScaleUpSignalError."
)
raise ElasticRuntimeError(
f"Elastic attempt {attempt} failed."
) from error

_logger.info("Retrying.")
except jax.errors.JaxRuntimeError as error:
if not elastic.is_error_due_to_slice_down(error):
raise

if self.new_slice_event.is_set():
_logger.info("Slice down event and new slice available detected.")
else:
_logger.info("Slice down event detected.")

_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()

if not retry_policy(attempt, error):
_logger.info("Retry policy rejected retry after JaxRuntimeError.")
raise ElasticRuntimeError(
f"Elastic attempt {attempt} failed."
) from error

_logger.info("Retrying.")

attempt += 1
return attempt_execution(attempt)
except ScaleUpSignalError as error:
handle_scale_up_error(attempt, error)
except jax.errors.JaxRuntimeError as error:
handle_slice_down_error(attempt, error)
attempt += 1
finally:
self.close()

return wrapper

Expand Down
Loading