From 1d016962c5de4b51a37611ebe0b33b83cddf5a28 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 23 Jan 2025 13:45:25 -0800 Subject: [PATCH] [JAX] Optimize array shard reordering This change adds a C++ implementation that uses `xla::ifrt::RemapArrays` to reorder shards of an array. This avoids creating intermediate single-device arrays and accelerates reordering shards within `jax.device_put()` implementation. PiperOrigin-RevId: 718998621 --- jax/_src/dispatch.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 4893e8335..7dea452c8 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -46,6 +46,7 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec @@ -364,11 +365,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, _logical_device_ids=(None if permute_order is None else tuple(permute_order.tolist()))) - new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + new_x = _reorder_shards(x, new_s, CopySemantics.ALIAS) return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(new_x) +def _reorder_shards(x, new_s, copy_semantics: CopySemantics): + """Reorders array shards to match the order indicated by the new sharding.""" + if xla_extension_version >= 304: + xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] + return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore + else: + assert copy_semantics == CopySemantics.ALIAS + return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + + @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`.