mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[JAX] Optimize array shard reordering
This change adds a C++ implementation that uses `xla::ifrt::RemapArrays` to reorder the shards of an array. This avoids creating intermediate single-device arrays and speeds up shard reordering within the `jax.device_put()` implementation. PiperOrigin-RevId: 718998621
This commit is contained in:
parent
8442d64a02
commit
1d016962c5
@ -46,6 +46,7 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -364,11 +365,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
|
||||
new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind,
|
||||
_logical_device_ids=(None if permute_order is None else
|
||||
tuple(permute_order.tolist())))
|
||||
new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
|
||||
new_x = _reorder_shards(x, new_s, CopySemantics.ALIAS)
|
||||
return api.jit(_identity_fn, out_shardings=target_sharding,
|
||||
donate_argnums=donate_argnums)(new_x)
|
||||
|
||||
|
||||
def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
  """Returns `x` with its shards arranged in the order described by `new_s`.

  Args:
    x: The array whose shards are to be reordered.
    new_s: The sharding that defines the desired shard order.
    copy_semantics: How the shards are handed to the result. The legacy path
      (extension version below 304) only accepts `CopySemantics.ALIAS`.

  Returns:
    The array produced by `xc.reorder_shards` when the extension is new
    enough, otherwise the array rebuilt from the shards in `x._arrays`.
  """
  if xla_extension_version < 304:
    # Legacy path: the extension has no `reorder_shards`, so the array is
    # rebuilt from its existing per-device shards, which can only alias them.
    assert copy_semantics == CopySemantics.ALIAS
    return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)

  # Fast path: translate the single semantics value to its xla_client form
  # and let the extension perform the reordering.
  converted = pxla.to_xc_copy_semantics([copy_semantics])
  xc_copy_semantics = converted[0]
  return xc.reorder_shards(x, new_s, xc_copy_semantics)  # type: ignore
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeferredShardArg:
|
||||
"""Deferred call to `pxla.shard_args`.
|
||||
|
Loading…
x
Reference in New Issue
Block a user