diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py index 1529e1e..1d45575 100644 --- a/pathwaysutils/elastic/elastic.py +++ b/pathwaysutils/elastic/elastic.py @@ -121,7 +121,7 @@ def validate(self) -> Set[int]: _logger.debug( "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error ) - if not is_error_due_to_slice_down(error): + if not is_error_due_to_slice_down(error, log_traceback=False): raise return active_slice_indices @@ -249,7 +249,9 @@ def wait_for_slices( time.sleep(time_to_sleep) -def is_error_due_to_slice_down(error: Exception) -> bool: +def is_error_due_to_slice_down( + error: Exception, log_traceback: bool = True +) -> bool: """Returns True if the error is due to slice down. The error types that are considered due to slice down are @@ -261,6 +263,7 @@ def is_error_due_to_slice_down(error: Exception) -> bool: Args: error: The error to check. + log_traceback: If True, log the traceback of the error. """ error_due_to_slice_down = False traceback_logging_level = logging.DEBUG @@ -293,6 +296,7 @@ def is_error_due_to_slice_down(error: Exception) -> bool: if not error_due_to_slice_down: _logger.debug("Caught an error not due to slice down") - _logger.log(traceback_logging_level, "Error details:", exc_info=True) + if log_traceback: + _logger.log(traceback_logging_level, "Error details:", exc_info=True) return error_due_to_slice_down diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index ccb8927..c4bce22 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -96,7 +96,9 @@ class Manager: all_slice_indices: Set[int] active_slice_indices: Set[int] new_slice_event: threading.Event - + available_inactive_slices: Set[int] + _stop_event: threading.Event | None + _monitor_thread: threading.Thread | None def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """Initializes the manager. @@ -113,6 +115,45 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: slice_to_devices=self.slice_to_devices ) self.new_slice_event = threading.Event() + self.available_inactive_slices = set() + + self._stop_event = None + self._monitor_thread = None + + def start_monitoring(self, poll_interval: float | int = 10) -> None: + """Starts the background monitor thread. + + Args: + poll_interval: The number of seconds to wait between activity checks. + """ + if self._monitor_thread is not None and self._monitor_thread.is_alive(): + _logger.warning("Monitor thread is already running.") + return + + self._stop_event = threading.Event() + self._monitor_thread = threading.Thread( + target=self._monitor_new_slices, + args=(self._stop_event, poll_interval), + daemon=True, + ) + self._monitor_thread.start() + _logger.info("Elastic monitor thread started with interval %s.", poll_interval) + + def close(self) -> None: + """Stops the background monitor thread.""" + if self._stop_event is not None: + self._stop_event.set() + if self._monitor_thread is not None: + _logger.info("Closing manager, waiting for monitor thread to stop...") + try: + self._monitor_thread.join(timeout=5) + except RuntimeError as e: + if "cannot join thread" in str(e): + pass + else: + raise + self._monitor_thread = None + self._stop_event = None @functools.cached_property def total_slice_count(self) -> int: @@ -171,37 +212,58 @@ def _cleanup_on_retry(self): for array in jax.live_arrays(): array.delete() - def _monitor_new_slices( - self, stop_event: threading.Event, poll_interval: float | int - ) -> None: - """Monitors for new slices and sets the `new_slice_event` if found.""" - while not stop_event.wait(poll_interval): - try: - if not self.inactive_slice_indices: - _logger.debug("No inactive slices to check.") - continue + def _check_inactive_slices(self) -> None: + """Checks inactive slices and updates available_inactive_slices.""" + if not self.inactive_slice_indices: + _logger.debug("No inactive slices to check.") + if self.available_inactive_slices: + self.available_inactive_slices.clear() + self.new_slice_event.clear() + return + + _logger.debug( + "Now checking inactive slices %s", self.inactive_slice_indices + ) + inactive_slice_to_devices = { + i: self.slice_to_devices[i] for i in self.inactive_slice_indices + } + found_slices = elastic.get_active_slice_indices( + inactive_slice_to_devices + ) - _logger.debug( - "Checking inactive slices: %s", self.inactive_slice_indices - ) - inactive_slice_to_devices = { - i: self.slice_to_devices[i] for i in self.inactive_slice_indices - } - newly_active_indices = elastic.get_active_slice_indices( - inactive_slice_to_devices - ) + _logger.debug( + "Found available and inactive slices %s", found_slices + ) - if newly_active_indices: - _logger.info( - "New slices found: %s. Setting new slice event.", - newly_active_indices, - ) - self.new_slice_event.set() - return + if found_slices != self.available_inactive_slices: + _logger.info( + "Newly available but inactive slices %s", found_slices + ) + self.available_inactive_slices = found_slices + if self.available_inactive_slices: + self.new_slice_event.set() + else: + self.new_slice_event.clear() - _logger.debug("No new slices found.") - except Exception: # pylint: disable=broad-exception-caught - _logger.exception("Error in monitor thread") + def _monitor_new_slices( + self, stop_event: threading.Event, poll_interval: float | int + ) -> None: + """Monitors for new slices and updates available_inactive_slices.""" + _logger.info("Elastic monitor thread started.") + try: + while not stop_event.wait(poll_interval): + try: + self._check_inactive_slices() + except Exception: # pylint: disable=broad-exception-caught + _logger.exception("Error in monitor thread loop") + except BaseException as e: + _logger.critical( + "Catastrophic error in monitor thread, thread is dying!", + exc_info=True, + ) + raise + finally: + _logger.info("Elastic monitor thread stopped.") def elastic_retry( self, @@ -286,6 +348,7 @@ def elastic_retry( def decorator(func: _F) -> _F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: + self.start_monitoring(poll_interval) def attempt_execution(attempt: int) -> Any: _logger.info("Elastic attempt %d", attempt) @@ -295,73 +358,59 @@ def attempt_execution(attempt: int) -> Any: poll_interval=poll_interval, timeout=timeout, ) + self.available_inactive_slices.clear() if pre_callback is not None: pre_callback() with jax.default_device(self.default_device): self.new_slice_event.clear() - stop_event = threading.Event() - - if target_slice_count < self.total_slice_count: - monitor_thread = threading.Thread( - target=self._monitor_new_slices, - args=(stop_event, poll_interval), - daemon=True, - ) - monitor_thread.start() - else: - monitor_thread = None + return func(*args, **kwargs) + + def handle_scale_up_error(attempt: int, error: ScaleUpSignalError) -> None: + _logger.info("Scale up requested.") + _elastic_event_cleanup() + + if on_elastic_event_callback is not None: + on_elastic_event_callback() + + if not retry_policy(attempt, error): + _logger.info("Retry policy rejected retry after ScaleUpSignalError.") + raise ElasticRuntimeError(f"Elastic attempt {attempt} failed.") from error + + _logger.info("Retrying.") + + def handle_slice_down_error(attempt: int, error: jax.errors.JaxRuntimeError) -> None: + if not elastic.is_error_due_to_slice_down(error): + raise + + if self.new_slice_event.is_set(): + _logger.info("Slice down event and new slice available detected.") + else: + _logger.info("Slice down event detected.") + + _elastic_event_cleanup() + + if on_elastic_event_callback is not None: + on_elastic_event_callback() + + if not retry_policy(attempt, error): + _logger.info("Retry policy rejected retry after JaxRuntimeError.") + raise ElasticRuntimeError(f"Elastic attempt {attempt} failed.") from error + + _logger.info("Retrying.") + try: + attempt = 1 + while True: try: - return func(*args, **kwargs) - finally: - stop_event.set() - if monitor_thread is not None: - monitor_thread.join() - - attempt = 1 - while True: - try: - return attempt_execution(attempt) - except ScaleUpSignalError as error: - _logger.info("Scale up requested.") - _elastic_event_cleanup() - - if on_elastic_event_callback is not None: - on_elastic_event_callback() - - if not retry_policy(attempt, error): - _logger.info( - "Retry policy rejected retry after ScaleUpSignalError." - ) - raise ElasticRuntimeError( - f"Elastic attempt {attempt} failed." - ) from error - - _logger.info("Retrying.") - except jax.errors.JaxRuntimeError as error: - if not elastic.is_error_due_to_slice_down(error): - raise - - if self.new_slice_event.is_set(): - _logger.info("Slice down event and new slice available detected.") - else: - _logger.info("Slice down event detected.") - - _elastic_event_cleanup() - - if on_elastic_event_callback is not None: - on_elastic_event_callback() - - if not retry_policy(attempt, error): - _logger.info("Retry policy rejected retry after JaxRuntimeError.") - raise ElasticRuntimeError( - f"Elastic attempt {attempt} failed." - ) from error - - _logger.info("Retrying.") - - attempt += 1 + return attempt_execution(attempt) + except ScaleUpSignalError as error: + handle_scale_up_error(attempt, error) + except jax.errors.JaxRuntimeError as error: + handle_slice_down_error(attempt, error) + attempt += 1 + finally: + self.close() return wrapper