Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 111 additions & 168 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@

import base64
import collections
from collections.abc import Mapping
import json
import logging
import math
import operator
from typing import Any, Callable, Dict, Mapping, Sequence
import warnings

import jax
from pathwaysutils import jax as pw_jax
from pathwaysutils import lru_cache
from pathwaysutils import plugin_executable
from pathwaysutils import reshard as pw_reshard
from pathwaysutils.experimental import split_by_mesh_axis


Expand Down Expand Up @@ -116,78 +117,6 @@ def _get_resharding_plan(
_get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan)


def _reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate: bool,
may_alias: bool | None,
jax_array_reshard_fn: Callable[..., Any],
**kwargs,
) -> Any:
"""Reshards `x` to `sharding`."""
flat_x, tree_def = jax.tree.flatten(x)
flat_sharding = jax.api_util.flatten_axes(
"reshard sharding", tree_def, sharding
)

# We must split the arrays into two groups:
# 1. jax.Array
# 2. non jax.Array
# For jax.Array, we will use the ifrt client to get the resharding plan and
# execute it.
# These arrays must be further split into groups based on the device set of
# the sharding, since plugin programs only supports execution on the same
# device set.
# For non jax.Array, we will use jax.device_put to put the array to the
# destination devices.
#
# We need to track what index each array is in the original pytree, so we can
# put them back together in the right order.
array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []}
jax_arrays = collections.defaultdict(array_info_lambda)
non_reshardable_arrays = array_info_lambda()
for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)):
if not isinstance(dst_sharding, jax.sharding.Sharding):
raise ValueError("`sharding` must contain only `jax.sharding.Sharding`")
if not isinstance(arr, jax.Array) or (
hasattr(arr, "dtype")
and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
):
non_reshardable_arrays["arrays"].append(arr)
non_reshardable_arrays["indices"].append(index)
non_reshardable_arrays["dst_shardings"].append(dst_sharding)
else:
device_set = frozenset(arr.sharding.device_set)
jax_arrays[device_set]["arrays"].append(arr)
jax_arrays[device_set]["indices"].append(index)
jax_arrays[device_set]["dst_shardings"].append(dst_sharding)

if non_reshardable_arrays["arrays"]:
non_reshardable_arrays["arrays"] = jax.device_put(
non_reshardable_arrays["arrays"],
non_reshardable_arrays["dst_shardings"],
donate=donate,
may_alias=may_alias,
)

for array_info in jax_arrays.values():
array_info["arrays"] = jax_array_reshard_fn(
array_info, donate=donate, **kwargs
)

result = [None] * len(flat_x)
for arr, idx in zip(
non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"]
):
result[idx] = arr
for array_info in jax_arrays.values():
for arr, idx in zip(array_info["arrays"], array_info["indices"]):
result[idx] = arr

return jax.tree.unflatten(tree_def, result)


def _sidechannel_jax_array_reshard(
array_info: Mapping[str, Any], *, donate: bool, cache_resharding_plans: bool
) -> Sequence[jax.Array]:
Expand All @@ -214,61 +143,6 @@ def _ifrt_jax_array_reshard(
)


def _reshard_with_sidechannel(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate: bool,
may_alias: bool | None,
cache_resharding_plans: bool,
) -> Any:
"""Reshards `x` to `sharding` using sidechannel."""
return _reshard(
x,
sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
cache_resharding_plans=cache_resharding_plans,
)


def _reshard_with_ifrt(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate: bool,
may_alias: bool | None,
) -> Any:
"""Reshards `x` to `sharding` using IFRT.

Note: Resharding plan caching is not applicable to the IFRT implementation
and is not supported by this function.

Args:
x: An array, scalar, or (nested) standard Python container thereof.
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
(must be a tree prefix of `x`), representing the device(s) and sharding to
which `x` should be sharded to. The result will be committed to the
device(s) of the sharding.
donate: If `True`, donate all input arrays, which may reduce the amount of
memory needed for resharding. Buffers donated to resharding should not be
reused.
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.

Returns:
A copy of `x` whose sharding is `sharding`.
"""
return _reshard(
x,
sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_ifrt_jax_array_reshard,
)


def reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
Expand All @@ -279,6 +153,9 @@ def reshard(
) -> Any:
"""Reshards `x` to `sharding`.

This function is an alternative to `pathwaysutils.reshard` that uses the
sidechannel resharding API for the final reshard.

Args:
x: An array, scalar, or (nested) standard Python container thereof.
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
Expand All @@ -291,38 +168,19 @@ def reshard(
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
recreating plans for the same resharding operation. May improve
performance for use cases where the same resharding operation is done
many times. May degrade performance if most reshardings operations are
different, since the cache will cause Pathways Components to remain
loaded for each cached plan. `False` by default. This parameter is only
used when `pw_jax.ifrt_reshard_available()` is false.
recreating plans for the same resharding operation.

Returns:
A copy of `x` whose sharding is `sharding`.
"""
if pw_jax.ifrt_reshard_available():
if cache_resharding_plans:
warnings.warn(
"`cache_resharding_plans` is only applicable when using the"
" sidechannel resharding implementation, but IFRT resharding is"
" available and will be used. The `cache_resharding_plans` argument"
" will be ignored."
)
return _reshard_with_ifrt(
x,
sharding,
donate=donate,
may_alias=may_alias,
)
else:
return _reshard_with_sidechannel(
x,
sharding,
donate=donate,
may_alias=may_alias,
cache_resharding_plans=cache_resharding_plans,
)
return pw_reshard.reshard_generic(
x,
sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
cache_resharding_plans=cache_resharding_plans,
)


class NoIntermediateShardingError(Exception):
Expand Down Expand Up @@ -564,40 +422,38 @@ def find_intermediate_sharding(
return intermediate_sharding, replicated_axes


def reshard_with_intermediate_sharding(
def reshard_with_intermediate_sharding_generic(
x: Any,
in_sharding: jax.sharding.Sharding,
out_sharding: jax.sharding.Sharding,
*,
jax_array_reshard_fn: Callable[..., Sequence[jax.Array]],
donate: bool = False,
may_alias: bool | None = None,
cache_resharding_plans: bool = False,
**kwargs: Any,
) -> Any:
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

This function is an alternative to `reshard` that may be faster and sometime
essential for certain sharding combinations by using an intermediate sharding
to avoid expensive all-gathers. If no beneficial intermediate sharding is
found, it falls back to standard resharding. See `find_intermediate_sharding`
for more details on when an intermediate sharding is used.
This function is a generic version of `reshard_with_intermediate_sharding`
that allows specifying the `jax_array_reshard_fn` to be used for the final
reshard.

Args:
x: An array, scalar, or (nested) standard Python container thereof.
in_sharding: The source sharding of `x`.
out_sharding: The target sharding for `x`.
jax_array_reshard_fn: The function used for the final reshard of JAX arrays.
donate: If `True`, donate all input arrays, which may reduce the amount of
memory needed for resharding. Buffers donated to resharding should not be
reused.
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.
cache_resharding_plans: Only used when resharding with sidechannel. If
`True`, uses a resharding plan cache to avoid recreating plans for the
same resharding operation.
**kwargs: Additional keyword arguments to be passed to
`jax_array_reshard_fn`.

Returns:
A copy of `x` whose sharding is `out_sharding`.
"""

try:
intermediate_sharding, replicated_axes_names = find_intermediate_sharding(
in_sharding, out_sharding
Expand All @@ -617,10 +473,97 @@ def reshard_with_intermediate_sharding(
donate=donate,
)

return reshard(
return pw_reshard.reshard_generic(
x_to_reshard,
out_sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=jax_array_reshard_fn,
**kwargs,
)


def reshard_with_intermediate_sharding(
x: Any,
in_sharding: jax.sharding.Sharding,
out_sharding: jax.sharding.Sharding,
*,
donate: bool = False,
may_alias: bool | None = None,
) -> Any:
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

This function is an alternative to `reshard` that may be faster and sometime
essential for certain sharding combinations by using an intermediate sharding
to avoid expensive all-gathers. If no beneficial intermediate sharding is
found, it falls back to standard resharding. See `find_intermediate_sharding`
for more details on when an intermediate sharding is used.

Uses the IFRT resharding API for the final reshard.

Args:
x: An array, scalar, or (nested) standard Python container thereof.
in_sharding: The source sharding of `x`.
out_sharding: The target sharding for `x`.
donate: If `True`, donate all input arrays, which may reduce the amount of
memory needed for resharding. Buffers donated to resharding should not be
reused.
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.

Returns:
A copy of `x` whose sharding is `out_sharding`.
"""
return reshard_with_intermediate_sharding_generic(
x,
in_sharding,
out_sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_ifrt_jax_array_reshard,
)


def sidechannel_reshard_with_intermediate_sharding(
x: Any,
in_sharding: jax.sharding.Sharding,
out_sharding: jax.sharding.Sharding,
*,
donate: bool = False,
may_alias: bool | None = None,
cache_resharding_plans: bool = False,
) -> Any:
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

This function is an alternative to `reshard` that may be faster and sometime
essential for certain sharding combinations by using an intermediate sharding
to avoid expensive all-gathers. If no beneficial intermediate sharding is
found, it falls back to standard resharding. See `find_intermediate_sharding`
for more details on when an intermediate sharding is used.

Uses the sidechannel resharding API for the final reshard.

Args:
x: An array, scalar, or (nested) standard Python container thereof.
in_sharding: The source sharding of `x`.
out_sharding: The target sharding for `x`.
donate: If `True`, donate all input arrays, which may reduce the amount of
memory needed for resharding. Buffers donated to resharding should not be
reused.
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
recreating plans for the same resharding operation.

Returns:
A copy of `x` whose sharding is `out_sharding`.
"""
return reshard_with_intermediate_sharding_generic(
x,
in_sharding,
out_sharding,
donate=donate,
may_alias=may_alias,
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
cache_resharding_plans=cache_resharding_plans,
)
Loading
Loading