From 0dbc86409266225499448836d3c0eb7a6fc7424d Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 3 Feb 2026 00:19:35 -0800 Subject: [PATCH] Refactor manager.py. Extracted functions from the manager class that may be useful without a manager and placed them in elastic.py. Refactored manager_test.py and created elastic_test.py with their respective tests. PiperOrigin-RevId: 864729672 --- pathwaysutils/elastic/elastic.py | 251 ++++++++++++++++++++++++++++ pathwaysutils/elastic/manager.py | 272 ++----------------------------- 2 files changed, 264 insertions(+), 259 deletions(-) create mode 100644 pathwaysutils/elastic/elastic.py diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py new file mode 100644 index 0000000..37029af --- /dev/null +++ b/pathwaysutils/elastic/elastic.py @@ -0,0 +1,251 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Elasticity utilities. + +This module provides standalone utilities for elastic training. It provides a +check for whether a `jax.errors.JaxRuntimeError` was caused by a slice down +event. It also provides a utility for waiting for slices to become active. 
+""" + +import collections +from collections.abc import Mapping, Sequence +import logging +import time +import traceback + +import jax +import numpy as np +from pathwaysutils.debug import timing + + +_logger = logging.getLogger(__name__) + +_SIMPLE_EXECUTION_TEST_VALUE = 100 +_ELASTIC_DOWN_ERROR_TYPES = [ + "DATA_LOSS", +] +_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ + "DEADLINE_EXCEEDED", + "NOT_FOUND", + "INTERNAL", +] + + +def _plus_one(x: jax.Array) -> jax.Array: + """Adds one to each element in the array. + + Used to test if a slice is active. + + Args: + x: The array to add one to. + + Returns: + The array with one added to each element. + """ + return x + 1 + + +def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array: + """Simple execution to test if a slice is active. + + This function is used to test if a slice is active. It executes a simple + computation on the devices and returns the result. If any of the devices are + not active, the returned array will raise a JaxRuntimeError when used. + + Simply executing this function is not enough to determine if the slice is + active. We also need to check the value of the returned array. + + Args: + devices: The devices to execute on. + + Returns: + The result of the execution. 
+ """ + if not devices: + raise ValueError("No devices") + + test_input = np.zeros(len(devices), dtype=float) + ( + _SIMPLE_EXECUTION_TEST_VALUE - 1 + ) + + return jax.pmap(_plus_one, devices=devices)(test_input) + + +def get_slice_to_devices( + devices: Sequence[jax.Device], +) -> dict[int, Sequence[jax.Device]]: + """Returns the mapping from slice index to devices.""" + slice_to_devices = collections.defaultdict(list) + for d in devices: + slice_to_devices[d.slice_index].append(d) + return dict(slice_to_devices) + + +@timing.timeit +def get_active_slice_indices( + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Returns the set of active slice indices.""" + if slice_to_devices is None: + slice_to_devices = get_slice_to_devices(tuple(jax.devices())) + + active_slice_indices = set() + + results = { + slice_index: _simple_execution(devices) + for slice_index, devices in slice_to_devices.items() + } + + for slice_index, x in results.items(): + _logger.info("Checking slice_index=%s", slice_index) + expected = ( + np.zeros(len(slice_to_devices[slice_index]), dtype=float) + + _SIMPLE_EXECUTION_TEST_VALUE + ) + try: + with timing.Timer(f"Checking {slice_index=}"): + jax.block_until_ready(x) + if np.allclose(x, expected): + active_slice_indices.add(slice_index) + _logger.info("slice_index=%s active", slice_index) + else: + _logger.error( + "Error with _simple_execution for slice_index=%s. " + "This should never happen. Expected: %s, Actual: %s", + slice_index, + expected, + x, + ) + raise ValueError( + f"Error with _simple_execution for slice_index={slice_index}." 
+ ) + except jax.errors.JaxRuntimeError as error: + if not is_error_due_to_slice_down(error): + raise + _logger.info("slice_index=%s bad", slice_index) + + _logger.info("active_slice_indices=%s", active_slice_indices) + + return active_slice_indices + + +def wait_for_slices( + slice_count: int, + poll_interval: float | int = 10, + timeout: float | int | None = None, + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Waits until at least `slice_count` slices become active. + + Args: + slice_count: The number of slices to wait for. + poll_interval: The minimum number of seconds to wait between availability + checks. If the check takes longer than this, the next check will start + immediately after the current check completes. Defaults to 10 seconds. + timeout: The maximum number of seconds to wait. If None, there is no + timeout. + slice_to_devices: A mapping from slice index to devices. If None, + `get_slice_to_devices(jax.devices())` is used. + + Returns: + The active slice indices. + + Raises: + TimeoutError: If the timeout is reached before the slices become + active. + """ + if slice_to_devices is None: + slice_to_devices = get_slice_to_devices(jax.devices()) + + start_time = time.time() + + while True: + check_start_time = time.time() + + active_slice_indices = get_active_slice_indices(slice_to_devices) + if len(active_slice_indices) >= slice_count: + _logger.info("%s slices active.", len(active_slice_indices)) + return active_slice_indices + + _logger.info( + "%s slices active. Wanting at least %s.", + len(active_slice_indices), + slice_count, + ) + + time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) + + if ( + timeout is not None + and (elapsed_time := time.time() - start_time) + time_to_sleep + >= timeout + ): + raise TimeoutError( + f"Timed out waiting for {slice_count} slices. Only" + f" {len(active_slice_indices)} active after" + f" {elapsed_time:.2f} seconds." 
+ f" Next check would occur after the timeout of {timeout}" + " seconds." + ) + + if time_to_sleep > 0: + _logger.info("Sleeping for %.2f seconds.", time_to_sleep) + + time.sleep(time_to_sleep) + + +def is_error_due_to_slice_down(error: Exception) -> bool: + """Returns True if the error is due to slice down. + + The error types that are considered due to slice down are + jax.errors.JaxRuntimeError with the following error kind in the message: + - DATA_LOSS + - DEADLINE_EXCEEDED + - NOT_FOUND + - INTERNAL + + Args: + error: The error to check. + """ + error_due_to_slice_down = False + traceback_logging_level = logging.DEBUG + + if isinstance(error, jax.errors.JaxRuntimeError): + if any( + error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES + ): + _logger.info("Caught an error due to slice down") + + error_due_to_slice_down = True + + elif any( + error_type in str(error) + for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES + ): + _logger.warning( + "Caught an error due that may or may not be due to slice down. This" + " error will be treated as due to slice down." + ) + traceback_logging_level = logging.WARNING + + error_due_to_slice_down = True + + if not error_due_to_slice_down: + _logger.info("Caught an error not due to slice down") + + _logger.log( + traceback_logging_level, "\n".join(traceback.format_exception(error)) + ) + + return error_due_to_slice_down diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 8bd712e..cb72bf5 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -18,37 +18,18 @@ events. It also provides a utility for waiting for slices to become active. 
""" -import collections from collections.abc import Mapping, Sequence import functools -import itertools import logging -import time -import traceback from typing import Any import jax -import numpy as np -from pathwaysutils.debug import timing +from pathwaysutils.elastic import elastic _logger = logging.getLogger(__name__) -def _plus_one(x: jax.Array) -> jax.Array: - """Adds one to each element in the array. - - Used to test if a slice is active. - - Args: - x: The array to add one to. - - Returns: - The array with one added to each element. - """ - return x + 1 - - class ElasticRuntimeError(RuntimeError): """Error raised when elasticity cannot continue.""" @@ -56,21 +37,10 @@ class ElasticRuntimeError(RuntimeError): class Manager: """Utility class for elastic training.""" - _devices: Sequence[jax.Device] _total_slice_count: int | None = None slice_to_devices: Mapping[int, Sequence[jax.Device]] active_slice_indices: set[int] - _SIMPLE_EXECUTION_TEST_VALUE = 100 - _ELASTIC_DOWN_ERROR_TYPES = [ - "DATA_LOSS", - ] - _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ - "DEADLINE_EXCEEDED", - "NOT_FOUND", - "INTERNAL", - ] - def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """Initializes the manager. 
@@ -79,24 +49,11 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """ if devices is None: devices = jax.devices() - self.devices = devices - - self.active_slice_indices = self.get_active_slice_indices() - - @property - def devices(self) -> Sequence[jax.Device]: - """Returns the devices.""" - return self._devices + self.slice_to_devices = elastic.get_slice_to_devices(devices) - @devices.setter - def devices(self, devices: Sequence[jax.Device]) -> None: - """Sets the devices.""" - self._devices = devices - - self.slice_to_devices = collections.defaultdict(list) - for d in self._devices: - self.slice_to_devices[d.slice_index].append(d) - self.slice_to_devices = dict(self.slice_to_devices) + self.active_slice_indices = elastic.get_active_slice_indices( + slice_to_devices=self.slice_to_devices + ) @property def total_slice_count(self) -> int: @@ -105,143 +62,6 @@ def total_slice_count(self) -> int: self._total_slice_count = len(self.slice_to_devices) return self._total_slice_count - def slice_device_count(self, slice_index: int) -> int: - """Returns the number of devices in a slice.""" - try: - return len(self.slice_to_devices[slice_index]) - except KeyError as error: - raise ValueError( - f"Slice {slice_index=} not found in {self.slice_to_devices=}" - ) from error - - def is_error_due_to_slice_down(self, error: Exception) -> bool: - """Returns True if the error is due to slice down. - - The error types that are considered due to slice down are - jax.errors.JaxRuntimeError with the following error kind in the message: - - DATA_LOSS - - DEADLINE_EXCEEDED - - NOT_FOUND - - INTERNAL - - Args: - error: The error to check. 
- """ - error_due_to_slice_down = False - traceback_logging_level = logging.DEBUG - - if isinstance(error, jax.errors.JaxRuntimeError): - if any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ERROR_TYPES - ): - _logger.info("Caught an error due to slice down") - - error_due_to_slice_down = True - - elif any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES - ): - _logger.warning( - "Caught an error due that may or may not be due to slice down. This" - " error will be treated as due to slice down." - ) - traceback_logging_level = logging.WARNING - - error_due_to_slice_down = True - - if not error_due_to_slice_down: - _logger.info("Caught an error not due to slice down") - - _logger.log( - traceback_logging_level, "\n".join(traceback.format_exception(error)) - ) - - return error_due_to_slice_down - - def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: - """Simple execution to test if a slice is active. - - This function is used to test if a slice is active. It executes a simple - computation on the devices and returns the result. If any of the devices are - not active, the returned array will fail with a JaxRuntimeError used. - - Simply executing this function is not enough to determine if the slice is - active. We also need to check the value of the returned array. - - Args: - devices: The devices to execute on. - - Returns: - The result of the execution. 
- """ - if not devices: - raise ValueError("No devices") - - test_input = np.zeros(len(devices), dtype=float) + ( - self._SIMPLE_EXECUTION_TEST_VALUE - 1 - ) - - return jax.pmap(_plus_one, devices=devices)(test_input) - - @timing.timeit - def get_active_slice_indices(self) -> set[int]: - """Returns the set of active slices indices.""" - active_slice_indices = set() - - results = { - slice_index: self._simple_execution(devices) - for slice_index, devices in self.slice_to_devices.items() - } - - for slice_index, x in results.items(): - _logger.info("Checking slice_index=%s", slice_index) - expected = ( - np.zeros(self.slice_device_count(slice_index), dtype=float) - + self._SIMPLE_EXECUTION_TEST_VALUE - ) - try: - with timing.Timer(f"Checking {slice_index=}"): - jax.block_until_ready(x) - if np.allclose(x, expected): - active_slice_indices.add(slice_index) - _logger.info("slice_index=%s good", slice_index) - else: - _logger.error( - "Error with _simple_execution for slice_index=%s. " - "This should never happen. Expected: %s, Actual: %s", - slice_index, - expected, - x, - ) - raise ValueError( - f"Error with _simple_execution for slice_index={slice_index}." - ) - except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): - raise - _logger.info("slice_index=%s bad", slice_index) - - _logger.info("active_slice_indices=%s", active_slice_indices) - - return active_slice_indices - - @property - def active_slice_to_devices(self) -> dict[int, Sequence[jax.Device]]: - """The mapping from a active slice to its devices.""" - return { - slice_index: self.slice_to_devices[slice_index] - for slice_index in self.active_slice_indices - } - - @property - def active_devices(self) -> list[jax.Device]: - """Returns the active slice indices.""" - return list( - itertools.chain.from_iterable(self.active_slice_to_devices.values()) - ) - @property def default_device(self) -> jax.Device: """Returns the device that should be set to the default device. 
@@ -259,14 +79,14 @@ def active_slice_count(self) -> int: return len(self.active_slice_indices) def scale_by_active_slices(self, x: int | float) -> int | float: - """Scale x by the number of good slices.""" + """Scale x by the number of active slices.""" if isinstance(x, int): quotient, remainder = divmod( x * self.active_slice_count, self.total_slice_count ) if remainder: raise ValueError( - f"Cannot scale {x=} by good slices because it will result in a " + f"Cannot scale {x=} by active slices because it will result in a " f"remainder of {remainder=}." ) return quotient @@ -275,75 +95,6 @@ def scale_by_active_slices(self, x: int | float) -> int | float: else: raise ValueError(f"Unsupported type: {type(x)=}") - def wait_for_slices( - self, - slice_count: int | None = None, - poll_interval: float | int = 10, - timeout: float | int | None = None, - ) -> set[int]: - """Waits until after at least `slice_count` slices become active. - - Args: - slice_count: The number of slices to wait for. If None, waits for all - slices to become active. - poll_interval: The minimum number of seconds to wait between availability - checks. If the check takes longer than this, the next check will start - immediately after the current check completes. Defaults to 10 seconds. - timeout: The maximum number of seconds to wait. If None, there is no - timeout. - - Returns: - The active slice indices - - Raises: - TimeoutError: If the timeout is reached before the slices become - active. - """ - if slice_count is None: - slice_count = self.total_slice_count - - start_time = time.time() - - while True: - check_start_time = time.time() - - active_slice_indices = self.get_active_slice_indices() - if len(active_slice_indices) >= slice_count: - _logger.info( - "%s/%s slices are active", - len(active_slice_indices), - self.total_slice_count, - ) - return active_slice_indices - - _logger.info( - "%s/%s slices active. 
Wanting at least %s/%s.", - len(active_slice_indices), - self.total_slice_count, - slice_count, - self.total_slice_count, - ) - - time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) - - if ( - timeout is not None - and (elapsed_time := time.time() - start_time) + time_to_sleep - >= timeout - ): - raise TimeoutError( - f"Timed out waiting for {slice_count} slices. Only" - f" {len(active_slice_indices)} active after" - f" {elapsed_time:.2f} seconds." - f" Next check would occur after the timeout of {timeout}" - " seconds." - ) - - if time_to_sleep > 0: - _logger.info("Sleeping for %.2f seconds.", time_to_sleep) - - time.sleep(time_to_sleep) - def pause_resume( self, max_retries: int, @@ -388,13 +139,16 @@ def wrapper(*args, **kwargs): "Elastic attempt %d out of %d", retry_index + 1, max_retries ) - self.active_slice_indices = self.wait_for_slices( - poll_interval=poll_interval, timeout=timeout + self.active_slice_indices = elastic.wait_for_slices( + slice_count=self.total_slice_count, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + timeout=timeout, ) return func(*args, **kwargs) except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): + if not elastic.is_error_due_to_slice_down(error): raise try: