[JAX] Optimize array shard reordering

This change adds a C++ implementation that uses `xla::ifrt::RemapArrays` to
reorder the shards of an array. This avoids creating intermediate single-device
arrays and speeds up shard reordering within the `jax.device_put()`
implementation.

PiperOrigin-RevId: 718998621
This commit is contained in:
Hyeontaek Lim 2025-01-23 13:45:25 -08:00 committed by jax authors
parent 8442d64a02
commit 1d016962c5

View File

@ -46,6 +46,7 @@ from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec
@ -364,11 +365,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind,
_logical_device_ids=(None if permute_order is None else
tuple(permute_order.tolist())))
new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
new_x = _reorder_shards(x, new_s, CopySemantics.ALIAS)
return api.jit(_identity_fn, out_shardings=target_sharding,
donate_argnums=donate_argnums)(new_x)
def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
  """Returns `x` with its shards arranged in the order that `new_s` specifies.

  Args:
    x: The array whose shards are to be reordered.
    new_s: The sharding that determines the new shard order.
    copy_semantics: Copy semantics to apply. The fallback path for older
      extension versions only accepts `CopySemantics.ALIAS`.

  Returns:
    The result of `xc.reorder_shards` when the extension provides it,
    otherwise an array rebuilt from the per-device arrays of `x`.
  """
  has_native_reorder = xla_extension_version >= 304

  if not has_native_reorder:
    # Legacy path: reassemble the array from its existing single-device
    # arrays under the new sharding. Only aliasing is handled here.
    assert copy_semantics == CopySemantics.ALIAS
    return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)

  # The conversion helper operates on a sequence, so wrap the single value
  # and take the only element of the result.
  (xc_copy_semantics,) = pxla.to_xc_copy_semantics([copy_semantics])[:1]
  return xc.reorder_shards(x, new_s, xc_copy_semantics)  # type: ignore
@dataclasses.dataclass(frozen=True)
class _DeferredShardArg:
"""Deferred call to `pxla.shard_args`.