diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index 5cfbd37..ca0b58d 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -15,17 +15,18 @@ import base64 import collections +from collections.abc import Mapping import json import logging import math import operator from typing import Any, Callable, Dict, Mapping, Sequence -import warnings import jax from pathwaysutils import jax as pw_jax from pathwaysutils import lru_cache from pathwaysutils import plugin_executable +from pathwaysutils import reshard as pw_reshard from pathwaysutils.experimental import split_by_mesh_axis @@ -116,78 +117,6 @@ def _get_resharding_plan( _get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan) -def _reshard( - x: Any, - sharding: jax.sharding.Sharding | Any, - *, - donate: bool, - may_alias: bool | None, - jax_array_reshard_fn: Callable[..., Any], - **kwargs, -) -> Any: - """Reshards `x` to `sharding`.""" - flat_x, tree_def = jax.tree.flatten(x) - flat_sharding = jax.api_util.flatten_axes( - "reshard sharding", tree_def, sharding - ) - - # We must split the arrays into two groups: - # 1. jax.Array - # 2. non jax.Array - # For jax.Array, we will use the ifrt client to get the resharding plan and - # execute it. - # These arrays must be further split into groups based on the device set of - # the sharding, since plugin programs only supports execution on the same - # device set. - # For non jax.Array, we will use jax.device_put to put the array to the - # destination devices. - # - # We need to track what index each array is in the original pytree, so we can - # put them back together in the right order. - array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []} - jax_arrays = collections.defaultdict(array_info_lambda) - non_reshardable_arrays = array_info_lambda() - for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)): - if not isinstance(dst_sharding, jax.sharding.Sharding): - raise ValueError("`sharding` must contain only `jax.sharding.Sharding`") - if not isinstance(arr, jax.Array) or ( - hasattr(arr, "dtype") - and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) - ): - non_reshardable_arrays["arrays"].append(arr) - non_reshardable_arrays["indices"].append(index) - non_reshardable_arrays["dst_shardings"].append(dst_sharding) - else: - device_set = frozenset(arr.sharding.device_set) - jax_arrays[device_set]["arrays"].append(arr) - jax_arrays[device_set]["indices"].append(index) - jax_arrays[device_set]["dst_shardings"].append(dst_sharding) - - if non_reshardable_arrays["arrays"]: - non_reshardable_arrays["arrays"] = jax.device_put( - non_reshardable_arrays["arrays"], - non_reshardable_arrays["dst_shardings"], - donate=donate, - may_alias=may_alias, - ) - - for array_info in jax_arrays.values(): - array_info["arrays"] = jax_array_reshard_fn( - array_info, donate=donate, **kwargs - ) - - result = [None] * len(flat_x) - for arr, idx in zip( - non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"] - ): - result[idx] = arr - for array_info in jax_arrays.values(): - for arr, idx in zip(array_info["arrays"], array_info["indices"]): - result[idx] = arr - - return jax.tree.unflatten(tree_def, result) - - def _sidechannel_jax_array_reshard( array_info: Mapping[str, Any], *, donate: bool, cache_resharding_plans: bool ) -> Sequence[jax.Array]: @@ -214,61 +143,6 @@ def _ifrt_jax_array_reshard( ) -def _reshard_with_sidechannel( - x: Any, - sharding: 
jax.sharding.Sharding | Any, - *, - donate: bool, - may_alias: bool | None, - cache_resharding_plans: bool, -) -> Any: - """Reshards `x` to `sharding` using sidechannel.""" - return _reshard( - x, - sharding, - donate=donate, - may_alias=may_alias, - jax_array_reshard_fn=_sidechannel_jax_array_reshard, - cache_resharding_plans=cache_resharding_plans, - ) - - -def _reshard_with_ifrt( - x: Any, - sharding: jax.sharding.Sharding | Any, - *, - donate: bool, - may_alias: bool | None, -) -> Any: - """Reshards `x` to `sharding` using IFRT. - - Note: Resharding plan caching is not applicable to the IFRT implementation - and is not supported by this function. - - Args: - x: An array, scalar, or (nested) standard Python container thereof. - sharding: A `Sharding` or a (nested) `Sharding` in standard Python container - (must be a tree prefix of `x`), representing the device(s) and sharding to - which `x` should be sharded to. The result will be committed to the - device(s) of the sharding. - donate: If `True`, donate all input arrays, which may reduce the amount of - memory needed for resharding. Buffers donated to resharding should not be - reused. - may_alias: If `True`, may alias the input array with the output array. May - reduce the amount of memory needed for resharding. Not used at the moment. - - Returns: - A copy of `x` whose sharding is `sharding`. - """ - return _reshard( - x, - sharding, - donate=donate, - may_alias=may_alias, - jax_array_reshard_fn=_ifrt_jax_array_reshard, - ) - - def reshard( x: Any, sharding: jax.sharding.Sharding | Any, @@ -279,6 +153,9 @@ def reshard( ) -> Any: """Reshards `x` to `sharding`. + This function is an alternative to `pathwaysutils.reshard` that uses the + sidechannel resharding API for the final reshard. + Args: x: An array, scalar, or (nested) standard Python container thereof. sharding: A `Sharding` or a (nested) `Sharding` in standard Python container @@ -291,38 +168,19 @@ def reshard( may_alias: If `True`, may alias the input array with the output array. May reduce the amount of memory needed for resharding. Not used at the moment. cache_resharding_plans: If `True`, uses a resharding plan cache to avoid - recreating plans for the same resharding operation. May improve - performance for use cases where the same resharding operation is done - many times. May degrade performance if most reshardings operations are - different, since the cache will cause Pathways Components to remain - loaded for each cached plan. `False` by default. This parameter is only - used when `pw_jax.ifrt_reshard_available()` is false. + recreating plans for the same resharding operation. Returns: A copy of `x` whose sharding is `sharding`. """ - if pw_jax.ifrt_reshard_available(): - if cache_resharding_plans: - warnings.warn( - "`cache_resharding_plans` is only applicable when using the" - " sidechannel resharding implementation, but IFRT resharding is" - " available and will be used. The `cache_resharding_plans` argument" - " will be ignored." 
- ) - return _reshard_with_ifrt( - x, - sharding, - donate=donate, - may_alias=may_alias, - ) - else: - return _reshard_with_sidechannel( - x, - sharding, - donate=donate, - may_alias=may_alias, - cache_resharding_plans=cache_resharding_plans, - ) + return pw_reshard.reshard_generic( + x, + sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_sidechannel_jax_array_reshard, + cache_resharding_plans=cache_resharding_plans, + ) class NoIntermediateShardingError(Exception): @@ -564,40 +422,38 @@ def find_intermediate_sharding( return intermediate_sharding, replicated_axes -def reshard_with_intermediate_sharding( +def reshard_with_intermediate_sharding_generic( x: Any, in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding, *, + jax_array_reshard_fn: Callable[..., Sequence[jax.Array]], donate: bool = False, may_alias: bool | None = None, - cache_resharding_plans: bool = False, + **kwargs: Any, ) -> Any: """Reshards `x` to `out_sharding`, using an intermediate sharding if possible. - This function is an alternative to `reshard` that may be faster and sometime - essential for certain sharding combinations by using an intermediate sharding - to avoid expensive all-gathers. If no beneficial intermediate sharding is - found, it falls back to standard resharding. See `find_intermediate_sharding` - for more details on when an intermediate sharding is used. + This function is a generic version of `reshard_with_intermediate_sharding` + that allows specifying the `jax_array_reshard_fn` to be used for the final + reshard. Args: x: An array, scalar, or (nested) standard Python container thereof. in_sharding: The source sharding of `x`. out_sharding: The target sharding for `x`. + jax_array_reshard_fn: The function used for the final reshard of JAX arrays. donate: If `True`, donate all input arrays, which may reduce the amount of memory needed for resharding. Buffers donated to resharding should not be reused. may_alias: If `True`, may alias the input array with the output array. May reduce the amount of memory needed for resharding. Not used at the moment. - cache_resharding_plans: Only used when resharding with sidechannel. If - `True`, uses a resharding plan cache to avoid recreating plans for the - same resharding operation. + **kwargs: Additional keyword arguments to be passed to + `jax_array_reshard_fn`. Returns: A copy of `x` whose sharding is `out_sharding`. """ - try: intermediate_sharding, replicated_axes_names = find_intermediate_sharding( in_sharding, out_sharding @@ -617,10 +473,97 @@ def reshard_with_intermediate_sharding( donate=donate, ) - return reshard( + return pw_reshard.reshard_generic( x_to_reshard, out_sharding, donate=donate, may_alias=may_alias, + jax_array_reshard_fn=jax_array_reshard_fn, + **kwargs, + ) + + +def reshard_with_intermediate_sharding( + x: Any, + in_sharding: jax.sharding.Sharding, + out_sharding: jax.sharding.Sharding, + *, + donate: bool = False, + may_alias: bool | None = None, +) -> Any: + """Reshards `x` to `out_sharding`, using an intermediate sharding if possible. + + This function is an alternative to `reshard` that may be faster and sometime + essential for certain sharding combinations by using an intermediate sharding + to avoid expensive all-gathers. If no beneficial intermediate sharding is + found, it falls back to standard resharding. See `find_intermediate_sharding` + for more details on when an intermediate sharding is used. + + Uses the IFRT resharding API for the final reshard. 
+ + Args: + x: An array, scalar, or (nested) standard Python container thereof. + in_sharding: The source sharding of `x`. + out_sharding: The target sharding for `x`. + donate: If `True`, donate all input arrays, which may reduce the amount of + memory needed for resharding. Buffers donated to resharding should not be + reused. + may_alias: If `True`, may alias the input array with the output array. May + reduce the amount of memory needed for resharding. Not used at the moment. + + Returns: + A copy of `x` whose sharding is `out_sharding`. + """ + return reshard_with_intermediate_sharding_generic( + x, + in_sharding, + out_sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_ifrt_jax_array_reshard, + ) + + +def sidechannel_reshard_with_intermediate_sharding( + x: Any, + in_sharding: jax.sharding.Sharding, + out_sharding: jax.sharding.Sharding, + *, + donate: bool = False, + may_alias: bool | None = None, + cache_resharding_plans: bool = False, +) -> Any: + """Reshards `x` to `out_sharding`, using an intermediate sharding if possible. + + This function is an alternative to `reshard` that may be faster and sometime + essential for certain sharding combinations by using an intermediate sharding + to avoid expensive all-gathers. If no beneficial intermediate sharding is + found, it falls back to standard resharding. See `find_intermediate_sharding` + for more details on when an intermediate sharding is used. + + Uses the sidechannel resharding API for the final reshard. + + Args: + x: An array, scalar, or (nested) standard Python container thereof. + in_sharding: The source sharding of `x`. + out_sharding: The target sharding for `x`. + donate: If `True`, donate all input arrays, which may reduce the amount of + memory needed for resharding. Buffers donated to resharding should not be + reused. + may_alias: If `True`, may alias the input array with the output array. May + reduce the amount of memory needed for resharding. Not used at the moment. + cache_resharding_plans: If `True`, uses a resharding plan cache to avoid + recreating plans for the same resharding operation. + + Returns: + A copy of `x` whose sharding is `out_sharding`. + """ + return reshard_with_intermediate_sharding_generic( + x, + in_sharding, + out_sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_sidechannel_jax_array_reshard, cache_resharding_plans=cache_resharding_plans, ) diff --git a/pathwaysutils/reshard.py b/pathwaysutils/reshard.py new file mode 100644 index 0000000..5046175 --- /dev/null +++ b/pathwaysutils/reshard.py @@ -0,0 +1,136 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Experimental resharding API for elastic device sets.""" + +import collections +import logging +from typing import Any, Callable, Mapping, Sequence + +import jax +import pathwaysutils.jax + + +_logger = logging.getLogger(__name__) + + +def reshard_generic( + x: Any, + sharding: jax.sharding.Sharding | Any, + *, + donate: bool, + may_alias: bool | None, + jax_array_reshard_fn: Callable[..., Any], + **kwargs, +) -> Any: + """Reshards `x` to `sharding`.""" + flat_x, tree_def = jax.tree.flatten(x) + flat_sharding = jax.api_util.flatten_axes( + "reshard sharding", tree_def, sharding + ) + + # We must split the arrays into two groups: + # 1. jax.Array + # 2. non jax.Array + # For jax.Array, we will use the provided `jax_array_reshard_fn`. + # For non jax.Array, we will use jax.device_put to put the array to the + # destination devices. This is necessary for new-style random keys and + # possibly other types of arrays. + # + # We need to track what index each array is in the original pytree, so we can + # put them back together in the right order. + array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []} + jax_arrays = collections.defaultdict(array_info_lambda) + non_reshardable_arrays = array_info_lambda() + for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)): + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `jax.sharding.Sharding`") + if not isinstance(arr, jax.Array) or ( + hasattr(arr, "dtype") + and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) + ): + non_reshardable_arrays["arrays"].append(arr) + non_reshardable_arrays["indices"].append(index) + non_reshardable_arrays["dst_shardings"].append(dst_sharding) + else: + device_set = frozenset(arr.sharding.device_set) + jax_arrays[device_set]["arrays"].append(arr) + jax_arrays[device_set]["indices"].append(index) + jax_arrays[device_set]["dst_shardings"].append(dst_sharding) + + if non_reshardable_arrays["arrays"]: + non_reshardable_arrays["arrays"] = jax.device_put( + non_reshardable_arrays["arrays"], + non_reshardable_arrays["dst_shardings"], + donate=donate, + may_alias=may_alias, + ) + + for array_info in jax_arrays.values(): + array_info["arrays"] = jax_array_reshard_fn( + array_info, donate=donate, **kwargs + ) + + result = [None] * len(flat_x) + for arr, idx in zip( + non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"] + ): + result[idx] = arr + for array_info in jax_arrays.values(): + for arr, idx in zip(array_info["arrays"], array_info["indices"]): + result[idx] = arr + + return jax.tree.unflatten(tree_def, result) + + +def _ifrt_jax_array_reshard( + array_info: Mapping[str, Any], *, donate: bool +) -> Sequence[jax.Array]: + return pathwaysutils.jax.transfer_to_shardings( + tuple(arr for arr in array_info["arrays"]), + tuple(array_info["dst_shardings"]), + donate, + ) + + +def reshard( + x: Any, + sharding: jax.sharding.Sharding | Any, + *, + donate: bool = False, + may_alias: bool | None = None, +) -> Any: + """Reshards `x` to `sharding`. + + Args: + x: An array, scalar, or (nested) standard Python container thereof. + sharding: A `Sharding` or a (nested) `Sharding` in standard Python container + (must be a tree prefix of `x`), representing the device(s) and sharding to + which `x` should be sharded to. The result will be committed to the + device(s) of the sharding. + donate: If `True`, donate all input arrays, which may reduce the amount of + memory needed for resharding. 
Buffers donated to resharding should not be + reused. + may_alias: If `True`, may alias the input array with the output array. May + reduce the amount of memory needed for resharding. Not used at the moment. + + Returns: + A copy of `x` whose sharding is `sharding`. + """ + return reshard_generic( + x, + sharding, + donate=donate, + may_alias=may_alias, + jax_array_reshard_fn=_ifrt_jax_array_reshard, + ) diff --git a/pathwaysutils/test/experimental/reshard_test.py b/pathwaysutils/test/experimental/reshard_test.py index 9cecd99..ca5acbb 100644 --- a/pathwaysutils/test/experimental/reshard_test.py +++ b/pathwaysutils/test/experimental/reshard_test.py @@ -21,39 +21,12 @@ from absl.testing import parameterized import jax import jax.numpy as jnp -from pathwaysutils import jax as pw_jax from pathwaysutils import plugin_executable from pathwaysutils.experimental import reshard class ReshardTest(parameterized.TestCase): - @parameterized.parameters( - dict(reshard_kwargs={"donate": True}, expected_donate=True), - dict(reshard_kwargs={"donate": False}, expected_donate=False), - dict(reshard_kwargs={}, expected_donate=False), - ) - def test_ifrt_reshard_donate( - self, reshard_kwargs: Mapping[str, Any], expected_donate: bool - ): - x = jnp.array([1, 2]) - devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) - - mock_transfer = self.enter_context( - mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) - ) - self.enter_context( - mock.patch.object( - pw_jax, "ifrt_reshard_available", return_value=True, autospec=True - ) - ) - - reshard.reshard(x, sharding, **reshard_kwargs) - - # Signature: transfer_to_shardings(arrays, shardings, donate) - mock_transfer.assert_called_with(mock.ANY, mock.ANY, expected_donate) - @parameterized.parameters( dict(reshard_kwargs={"donate": True}, expected_donate=True), dict(reshard_kwargs={"donate": False}, expected_donate=False), @@ -66,11 +39,6 @@ def test_sidechannel_reshard_donate( devices = jax.devices() sharding = jax.sharding.SingleDeviceSharding(devices[0]) - self.enter_context( - mock.patch.object( - pw_jax, "ifrt_reshard_available", return_value=False, autospec=True - ) - ) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable", autospec=True) ) @@ -83,31 +51,6 @@ def test_sidechannel_reshard_donate( request = json.loads(json_request) self.assertEqual(request["reshardRequest"]["donateInput"], expected_donate) - @parameterized.parameters(True, False, None) - def test_ifrt_reshard_cache_resharding_plans(self, cache: bool | None): - x = jnp.array([1, 2]) - devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) - - mock_transfer = self.enter_context( - mock.patch.object(pw_jax, "transfer_to_shardings") - ) - self.enter_context( - mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=True) - ) - - if cache is None: - reshard.reshard(x, sharding) - elif cache: - with self.assertWarnsRegex( - UserWarning, "cache_resharding_plans` is only applicable" - ): - reshard.reshard(x, sharding, cache_resharding_plans=cache) - else: - reshard.reshard(x, sharding, cache_resharding_plans=cache) - - mock_transfer.assert_called_once() - @parameterized.parameters( dict(cache=True, expected_cache=True), dict(cache=False, expected_cache=False), @@ -120,9 +63,6 @@ def test_sidechannel_reshard_cache_resharding_plans( devices = jax.devices() sharding = jax.sharding.SingleDeviceSharding(devices[0]) - self.enter_context( - mock.patch.object(pw_jax, "ifrt_reshard_available", 
return_value=False) - ) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable") ) @@ -144,4 +84,27 @@ def test_sidechannel_reshard_cache_resharding_plans( 1 if expected_cache else 0, ) -if __name__ == "__main__": absltest.main() + def test_sidechannel_reshard_pytree(self): + x = {"a": jnp.array([1]), "b": [jnp.array([2])]} + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + # Tree prefix sharding + tree_sharding = {"a": sharding, "b": [sharding]} + + mock_pe = self.enter_context( + mock.patch.object(plugin_executable, "PluginExecutable", autospec=True) + ) + mock_pe.return_value.call.return_value = ( + [mock.Mock(), mock.Mock()], + mock.Mock(), + ) + + reshard.reshard(x, tree_sharding) + + self.assertEqual(mock_pe.call_count, 1) + (json_request,), _ = mock_pe.call_args + request = json.loads(json_request) + self.assertLen(request["reshardRequest"]["inSharding"], 2) + +if __name__ == "__main__": + absltest.main() diff --git a/pathwaysutils/test/reshard_test.py b/pathwaysutils/test/reshard_test.py new file mode 100644 index 0000000..ce465d0 --- /dev/null +++ b/pathwaysutils/test/reshard_test.py @@ -0,0 +1,73 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping +from typing import Any +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from pathwaysutils import jax as pw_jax +from pathwaysutils import reshard + + +class ReshardTest(parameterized.TestCase): + + @parameterized.parameters( + dict(reshard_kwargs={"donate": True}, expected_donate=True), + dict(reshard_kwargs={"donate": False}, expected_donate=False), + dict(reshard_kwargs={}, expected_donate=False), + ) + def test_ifrt_reshard_donate( + self, reshard_kwargs: Mapping[str, Any], expected_donate: bool + ): + x = jnp.array([1, 2]) + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + + mock_transfer = self.enter_context( + mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) + ) + mock_transfer.return_value = [mock.Mock()] + + reshard.reshard(x, sharding, **reshard_kwargs) + + mock_transfer.assert_called() + args, _ = mock_transfer.call_args + self.assertEqual(args[2], expected_donate) + + def test_ifrt_reshard_pytree(self): + x = {"a": jnp.array([1]), "b": [jnp.array([2])]} + devices = jax.devices() + sharding = jax.sharding.SingleDeviceSharding(devices[0]) + # Tree prefix sharding + tree_sharding = {"a": sharding, "b": [sharding]} + + mock_transfer = self.enter_context( + mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) + ) + mock_transfer.return_value = [mock.Mock(), mock.Mock()] + + reshard.reshard(x, tree_sharding) + + # Since they are on the same device set, they should be grouped together. 
+ self.assertEqual(mock_transfer.call_count, 1) + (arrays, shardings, _), _ = mock_transfer.call_args + self.assertLen(arrays, 2) + self.assertLen(shardings, 2) + +if __name__ == "__main__": + absltest.main()
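
For reference, a minimal usage sketch of the entry points touched by this change (not part of the patch). It assumes a Pathways/proxy JAX backend; the mesh axis names, array shapes, and partition specs below are illustrative placeholders only.

```python
# Illustrative only: exercises the public reshard API and the experimental
# sidechannel / intermediate-sharding variants introduced or kept by this diff.
import jax
import jax.numpy as jnp
import numpy as np

from pathwaysutils import reshard as pw_reshard
from pathwaysutils.experimental import reshard as experimental_reshard

# Build a toy 2-D mesh over whatever devices are available (placeholder axes).
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = jax.sharding.Mesh(devices, ("data", "model"))
in_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
out_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(None, "model")
)

x = jax.device_put(jnp.ones((8, 8)), in_sharding)

# Default path: pathwaysutils.reshard.reshard uses the IFRT transfer API.
y = pw_reshard.reshard(x, out_sharding, donate=False)

# Experimental path: sidechannel resharding, optionally caching plans.
z = experimental_reshard.reshard(
    x, out_sharding, donate=False, cache_resharding_plans=True
)

# Intermediate-sharding variant that may avoid expensive all-gathers.
w = experimental_reshard.reshard_with_intermediate_sharding(
    x, in_sharding, out_sharding, donate=False
)
```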