diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 4893e8335..7dea452c8 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -46,6 +46,7 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec @@ -364,11 +365,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, _logical_device_ids=(None if permute_order is None else tuple(permute_order.tolist()))) - new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + new_x = _reorder_shards(x, new_s, CopySemantics.ALIAS) return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(new_x) +def _reorder_shards(x, new_s, copy_semantics: CopySemantics): + """Reorders array shards to match the order indicated by the new sharding.""" + if xla_extension_version >= 304: + xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] + return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore + else: + assert copy_semantics == CopySemantics.ALIAS + return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + + @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`.