Remove indices and devices from shard_arg_handlers and shard_args.

This only affects the Python dispatch path. It has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).

If your code ends up taking the Python dispatch path, then something is going wrong anyway.

PiperOrigin-RevId: 596081987
Yash Katariya 2024-01-05 14:16:32 -08:00 committed by jax authors
parent ed62f28164
commit b8098b1782
11 changed files with 109 additions and 169 deletions
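
For third-party code that registers its own handler in `pxla.shard_arg_handlers`, the practical upshot is that a handler now takes only `(arg, sharding)` and recovers devices and per-device indices from the sharding itself. A minimal sketch follows (not part of this commit; `_shard_foo` and `FooTy` are hypothetical, and the helpers used are the ones this diff introduces or keeps):

```python
import numpy as np
from jax._src import api_util
from jax._src.interpreters import pxla

# Old contract:  def _shard_foo(x, devices, indices, sharding): ...
# New contract:  the handler derives devices/indices from `sharding` itself.
def _shard_foo(x, sharding):
  # Per-device index tuples, e.g. ((slice(0, 4),), (slice(4, 8),), ...)
  indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
  # Addressable devices for this sharding (cached helper added in this change).
  devices = pxla.get_addressable_devices_for_shard_arg(sharding)
  data = np.asarray(x)
  aval = api_util.shaped_abstractify(data)
  return pxla.batched_device_put(
      aval, sharding, [data[idx] for idx in indices], devices)

# Registration itself is unchanged; only the callable's arity differs:
# pxla.shard_arg_handlers[FooTy] = _shard_foo
```

`shard_args` follows the same pattern: `pxla.shard_args([sharding], [arg])` instead of the old `pxla.shard_args(devices, [indices], [sharding], [arg])`.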

View File

@@ -63,6 +63,7 @@ from jax._src.api_util import (
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
@@ -1845,7 +1846,8 @@ def _cpp_pmap(
     return out, fastpath_data

   cpp_mapped_f = pmap_lib.pmap(
-      fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
+      fun, cache_miss, static_broadcasted_tuple,
+      pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
       pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)

View File

@@ -834,8 +834,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   if not candidates_list:
     # This array isn't sharded correctly. Reshard it via host roundtrip.
     # TODO(skye): more efficient reshard?
-    return pxla.shard_arg(x._value, devices, indices, sharding,
-                          canonicalize=False)
+    return pxla.shard_arg(x._value, sharding, canonicalize=False)
   # Try to find a candidate buffer already on the correct device,
   # otherwise copy one of them.
   for buf in candidates_list:
@@ -848,10 +847,11 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)


-def _array_shard_arg(x, devices, indices, sharding):
+def _array_shard_arg(x, sharding):
   x._check_if_deleted()
   x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
+  indices = sharding.addressable_devices_indices_map(x.shape).values()
   if not x.is_fully_addressable:
     if tuple(x_indices) == tuple(indices):
       return x
@@ -859,16 +859,15 @@ def _array_shard_arg(x, devices, indices, sharding):
       raise NotImplementedError(
           "Cannot reshard an input that is not fully addressable")
   else:
+    devices = pxla.get_addressable_devices_for_shard_arg(sharding)
     if tuple(x_indices) == tuple(indices):
-      return xc.copy_array_to_devices_with_sharding(
-          x, list(devices), sharding)
+      return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
     # Resharding starts here:
     if dispatch.is_single_device_sharding(x.sharding):
       return shard_device_array(x, devices, indices, sharding)
     else:
       return shard_sharded_device_array_slow_path(x, devices, indices, sharding)

 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg

View File

@@ -124,12 +124,8 @@ class RuntimeTokenSet(threading.local):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
-    s = NamedSharding(pxla.Mesh(devices, axis_names=["dev"]),
-                      PartitionSpec([]))
-    indices = tuple(
-        s.addressable_devices_indices_map(tok.shape).values())
-    sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0]
+    s = jax.sharding.GSPMDSharding.get_replicated(devices)
+    sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
@@ -331,8 +327,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
   result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
-  map_ = s.devices_indices_map(aval.shape)  # type: ignore
-  return result_handler(pxla.shard_arg(x, list(map_), list(map_.values()), s))
+  return result_handler(pxla.shard_arg(x, s))

 def _override_get_device_assignment(sharding, *args, **kwargs):
   da = sharding._device_assignment

View File

@@ -106,60 +106,46 @@ ShardingSpec = sharding_specs.ShardingSpec
 def identity(x): return x

-def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
-  """Returns a list of size len(devices) containing per-device buffers.
-
-  For the C++ pmap path, we fallback to Python (this function) to shard
-  arguments that are not supported by the C++ `ShardArg`.
-
-  Args:
-    arg: The Python argument.
-    devices: The list of devices to shard over.
-    arg_indices: A list of `len(devices)` indices to use to shard the argument.
-  """
+def shard_arg(arg, sharding, canonicalize=True):
   if canonicalize:
     arg = xla.canonicalize_dtype(arg)
-  return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
+  return shard_arg_handlers[type(arg)](arg, sharding)


 @profiler.annotate_function
 def shard_args(
-    devices: Sequence[xb.xla_client.Device],
-    indices: Sequence[Sequence[Index]],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    args,
+    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
 ) -> Sequence[jax.Array]:
-  """Shard each argument data array along its leading axis.
-
-  Args:
-    devices: sequence of Devices mapping replica index to a physical device.
-    indices: sequence of the same length as `args` describing how each arg
-      should be sharded/replicated across `devices`. Each element in `indices`
-      is the same length as `devices`.
-    args: a sequence of JaxTypes representing arguments to be sharded according
-      to `indices` and placed on `devices`.
-
-  Returns:
-    A list of length matching args, containing lists of per-device buffers
-    for each argument.
-  """
-  return [shard_arg(arg, devices, indices[i], shardings[i])
-          for i, arg in enumerate(args)]
+  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]


-shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
+shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}


-def _shard_token(x, devices, indices, sharding):
+@lru_cache(maxsize=1024)
+def get_addressable_devices_for_shard_arg(
+    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
+  return s._addressable_device_assignment
+
+@lru_cache(maxsize=1024)
+def _get_replicated_slices(num_addressable_devices: int):
+  return ((slice(None),),) * num_addressable_devices
+
+def _shard_token(x, sharding):
+  devices = get_addressable_devices_for_shard_arg(sharding)
+  indices = _get_replicated_slices(len(devices))
   zeros = np.zeros((), dtype=np.dtype(np.bool_))
   aval = api_util.shaped_abstractify(zeros)
-  return batched_device_put(aval, sharding, [zeros for i in indices], devices)
+  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
 shard_arg_handlers[core.Token] = _shard_token

-def _masked_array_error(x, devices, indices, sharding):
+def _masked_array_error(x, sharding):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

-def _shard_array(x, devices, indices, sharding):
+def _shard_array(x, sharding):
+  indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+  devices = get_addressable_devices_for_shard_arg(sharding)
   if x.dtype == dtypes.float0:
     x = np.zeros(x.shape, dtype=np.dtype(bool))
   aval = api_util.shaped_abstractify(x)
@@ -167,8 +153,8 @@ def _shard_array(x, devices, indices, sharding):
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array

-def _shard_darray(x, devices, indices, sharding):
-  return shard_arg(x._data, devices, indices, sharding)
+def _shard_darray(x, sharding):
+  return shard_arg(x._data, sharding)
 shard_arg_handlers[core.DArray] = _shard_darray


 def batched_device_put(aval: core.ShapedArray,
@@ -183,7 +169,7 @@ def batched_device_put(aval: core.ShapedArray,
   if len(bufs) == len(xs):
     return array.ArrayImpl(
         aval, sharding, bufs, committed=committed, _skip_checks=True)
-  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
+  return xc.batched_device_put(aval, sharding, xs, list(devices), committed)  # type: ignore


 def shard_aval(size, axis: int, aval):
   try:
@@ -849,8 +835,8 @@ class UnloadedPmapExecutable:
                      if spec.sharding_spec is not None else None)
     handle_outs = local_avals_to_results_handler(self.local_output_avals,
                                                  self.output_shardings)
-    handle_args = InputsHandler(self.compiled.local_devices(),
-                                self.input_shardings, input_indices)
+    handle_args = InputsHandler(self.input_shardings,
+                                self.compiled.local_devices(), input_indices)
     execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
                                     self.backend, handle_args, handle_outs,
                                     self.unordered_effects,
@@ -1054,9 +1040,8 @@ def _get_pmap_sharding(devices, specs):
 class InputsHandler:
   __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

-  def __init__(self, local_devices, in_shardings, input_indices):
-    self.handler = partial(
-        shard_args, local_devices, input_indices, in_shardings)
+  def __init__(self, in_shardings, local_devices=None, input_indices=None):
+    self.handler = partial(shard_args, in_shardings)
     self.local_devices = local_devices
     self.in_shardings = in_shardings
     self.input_indices = input_indices
@@ -2248,37 +2233,36 @@ class MeshComputation(stages.XlaLowering):
     return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())


-@lru_cache(maxsize=1024)
-def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
-  if ndim is None:
-    return ((slice(None),),) * num_addressable_devices
-  else:
-    return ((slice(None),) * ndim,) * num_addressable_devices
-
-
-def _get_input_indices(
-    avals: Sequence[ShapedArray],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
-) -> Sequence[tuple[Index | None, ...]]:
-
-  input_indices = []
-  if not isinstance(da_object, _DeviceAssignment):
-    da_object = _create_da_object(tuple(da_object))
-  num_addressable_devices = len(da_object.addressable_device_list)
-
-  for aval, sharding in zip(avals, shardings):
-    if aval is core.abstract_token:
-      index = _get_replicated_slices(num_addressable_devices, None)
-    else:
-      if sharding.is_fully_replicated:
-        index = _get_replicated_slices(num_addressable_devices, aval.ndim)
-      else:
-        index = tuple(
-            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
-    input_indices.append(index)
-
-  return input_indices
+if xla_extension_version < 229:
+  def _get_input_indices(
+      avals: Sequence[ShapedArray],
+      shardings: Sequence[sharding_impls.XLACompatibleSharding],
+      da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
+  ) -> Sequence[tuple[Index | None, ...]]:
+
+    input_indices = []
+    if not isinstance(da_object, _DeviceAssignment):
+      da_object = _create_da_object(tuple(da_object))
+    num_addressable_devices = len(da_object.addressable_device_list)
+
+    def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
+      if ndim is None:
+        return ((slice(None),),) * num_addressable_devices
+      else:
+        return ((slice(None),) * ndim,) * num_addressable_devices
+
+    for aval, sharding in zip(avals, shardings):
+      if aval is core.abstract_token:
+        index = _get_replicated_slices(num_addressable_devices, None)
+      else:
+        if sharding.is_fully_replicated:
+          index = _get_replicated_slices(num_addressable_devices, aval.ndim)
+        else:
+          index = tuple(
+              sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
+      input_indices.append(index)
+
+    return input_indices


 def get_gspmd_shardings_from_executable(
@@ -2604,10 +2588,13 @@ class UnloadedMeshExecutable:
   all_args_info: AllArgsInfo | None

   def build_unsafe_call(self):
-    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
-                                       self.device_assignment)
-    handle_args = InputsHandler(self.xla_executable.local_devices(),
-                                self.input_shardings, input_indices)
+    if xla_extension_version >= 229:
+      handle_args = InputsHandler(self.input_shardings)
+    else:
+      input_indices = _get_input_indices(self.input_avals, self.input_shardings,
+                                         self.device_assignment)
+      handle_args = InputsHandler(
+          self.input_shardings, self.xla_executable.local_devices(), input_indices)
     handle_outs = global_avals_to_results_handler(
         self.output_avals, self.output_shardings, self.committed,
         self.are_out_shardings_from_xla)  # type: ignore  # arg-type
@@ -2755,6 +2742,7 @@ class MeshExecutableFastpathData(NamedTuple):
   out_avals: Sequence[ShapedArray]
   out_committed: Sequence[bool]
   kept_var_bitvec: Iterable[bool]
+  # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
   arg_handler_devices: Sequence[xc.Device]
   arg_handler_indices: Sequence[tuple[Index | None, ...]]
@@ -2865,13 +2853,20 @@ class MeshExecutable(stages.XlaExecutable):
       return outs, fastpath_data

     if xla_extension_version >= 226:
-      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                          tree_util.dispatch_registry, shard_arg)
+      return xc._xla.pjit(
+          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+          tree_util.dispatch_registry,
+          shard_arg if xla_extension_version >= 229 else temp_shard_arg)  # type: ignore
     else:
       return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],  # type: ignore
                           tree_util.dispatch_registry)


+# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
+def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
+  return shard_arg(arg, sharding)
+
+
 def check_arg_avals_for_call(ref_avals, arg_avals,
                              jaxpr_debug_info: core.JaxprDebugInfo | None = None):
   if len(ref_avals) != len(arg_avals):
@@ -2926,20 +2921,15 @@ def _compile_replicated_mesh_executable_from_hlo(
   in_shardings = semantics_in_shardings.shardings
   out_shardings = semantics_out_shardings.shardings

-  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
-  if pmap_nreps > 1:
-    # For a jit wrapping a pmap, replicate each input index to match the
-    # devices of the replicated jit computation.
-    input_indices = [index * pmap_nreps for index in input_indices]
   kept_var_idx = set(kept_var_idx)
   # Will compute out_handler with executable information.
   unsafe_call = backend.compile_replicated(
       is_trivial=False, name=name, computation=computation,
       compile_options=compile_options, host_callbacks=host_callbacks,
       has_unordered_effects=has_unordered_effects,
-      ordered_effects=ordered_effects, in_avals=global_in_avals,
-      in_indices=input_indices, in_shardings=in_shardings,
-      kept_var_idx=kept_var_idx,
+      device_assignment=da, ordered_effects=ordered_effects,
+      in_avals=global_in_avals,
+      in_shardings=in_shardings, kept_var_idx=kept_var_idx,
       out_avals=global_out_avals, out_shardings=out_shardings,
       committed=committed, pmap_nreps=pmap_nreps)
   xla_executable = None

View File

@@ -234,7 +234,8 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
         getattr(fun, "__name__", "<unnamed function>"),
         fun, cache_miss, static_argnums, static_argnames,
         donate_argnums, tree_util.dispatch_registry,
-        pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
   else:
     cpp_pjit_f = xc._xla.pjit(  # type: ignore
         getattr(fun, "__name__", "<unnamed function>"),
@@ -1348,9 +1349,11 @@ def _pjit_call_impl(*args, jaxpr,
   has_explicit_sharding = _pjit_explicit_sharding(
       in_shardings, out_shardings, None, None)
   if xla_extension_version >= 226:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                        tree_util.dispatch_registry, pxla.shard_arg,
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
+    return xc._xla.pjit(
+        name, f, call_impl_cache_miss, [], [], donated_argnums,
+        tree_util.dispatch_registry,
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(has_explicit_sharding))(*args)
   else:
     return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
                         tree_util.dispatch_registry,

View File

@@ -636,20 +636,11 @@ xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x


-def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
-  aval = x.aval
-  key_shape = aval.dtype._impl.key_shape
+def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
   arr = x._base_array
-
-  # TODO(yashkatariya,frostig): This assumes that the last dimensions are not
-  # sharded. This is only true when enable_custom_prng is True.
-  trailing_inds = [slice(None)] * len(key_shape)
-  phys_indices = [(*inds, *trailing_inds) for inds in indices]
   phys_sharding = make_key_array_phys_sharding(
-      aval, sharding, is_sharding_from_xla=False)
-  return pxla.shard_arg_handlers[type(arr)](
-      arr, devices, phys_indices, phys_sharding
-  )
+      x.aval, sharding, is_sharding_from_xla=False)
+  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)


 pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler

View File

@@ -30,9 +30,13 @@ XLADeviceAssignment = Sequence[Device]
 @functools.lru_cache(maxsize=4096)
 def _addressable_devices_indices_map(
     sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
+  global_map = sharding.devices_indices_map(global_shape)
   if sharding.is_fully_addressable:
-    return sharding.devices_indices_map(global_shape)
-  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
+    return global_map
+  if hasattr(sharding, '_internal_device_list'):
+    return {d: global_map[d]
+            for d in sharding._internal_device_list.addressable_device_list}
+  return {d: ind for d, ind in global_map.items()
          if d.process_index == d.client.process_index()}

View File

@@ -110,6 +110,10 @@ class XLACompatibleSharding(sharding.Sharding):
   @functools.cached_property
   def _addressable_device_assignment(self) -> XLADeviceAssignment:
+    if self.is_fully_addressable:
+      return self._device_assignment
+    if hasattr(self, '_internal_device_list'):
+      return tuple(self._internal_device_list.addressable_device_list)
     return tuple(d for d in self._device_assignment
                  if d.process_index == d.client.process_index())

View File

@@ -3042,8 +3042,8 @@ class FooArray:
   size = property(lambda self: self.data.size // 2)
   ndim = property(lambda self: self.data.ndim - 1)

-def shard_foo_array_handler(x, devices, indices, sharding):
-  device, = devices
+def shard_foo_array_handler(x, sharding):
+  device, = sharding._addressable_device_assignment
   aval = core.raise_to_shaped(core.get_aval(x.data))
   return pxla.batched_device_put(
       aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])

View File

@@ -45,7 +45,7 @@ from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import Sharding, _addressable_devices_indices_map
+from jax._src.sharding import Sharding
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
@@ -60,7 +60,7 @@ from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
-from jax._src.util import curry, unzip2, safe_zip
+from jax._src.util import curry, unzip2

 config.parse_flags_with_absl()
@@ -3546,32 +3546,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
     self.assertEqual(out4.devices(), {jax.devices()[1]})

-  def test_get_indices_cache(self):
-    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    ns = NamedSharding(mesh, P('x'))
-    ns2 = NamedSharding(mesh, P('x', 'y'))
-
-    np_inp = np.arange(16).reshape(8, 2)
-    arr1 = jax.device_put(np_inp, ns)
-    arr2 = jax.device_put(np_inp, ns2)
-    arr3 = jax.device_put(np_inp, ns)
-
-    _addressable_devices_indices_map.cache_clear()
-
-    cache_info1 = _addressable_devices_indices_map.cache_info()
-    out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3)
-    cache_info2 = _addressable_devices_indices_map.cache_info()
-    self.assertArraysEqual(out, np_inp * 3)
-
-    # arr3 and arr1 should have the same GSPMDSharding objects internally.
-    # So there will be 2 hits in _addressable_devices_indices_map,
-    # One in `pxla._get_input_indices` and second in `_array_shard_arg`.
-    self.assertEqual(cache_info2.hits, cache_info1.hits + 2)
-    # There will double the amount of misses as hits because arr1 and arr2's
-    # sharding are not the same. So 2 misses in _addressable_devices_indices_map
-    # and 2 in _array_shard_arg.
-    self.assertEqual(cache_info2.misses, cache_info1.misses + 4)
-
   def test_same_named_sharding_pspec_on_eager_ops(self):
     mesh = jtu.create_global_mesh((1, 8, 1), ('x', 'y', 'z'))
     sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z'))
@@ -4261,26 +4235,6 @@ class UtilTest(jtu.JaxTestCase):
         sharding_impls.array_mapping_to_axis_resources(inp), expected_out
     )

-  def test_get_input_indices_fully_replicated(self):
-    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    global_in_aval1 = core.ShapedArray((4, 4), jnp.int32)
-    global_in_aval2 = core.ShapedArray((4, 4, 4), jnp.int32)
-    global_in_aval3 = core.ShapedArray((), jnp.int32)
-    in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
-
-    mp = NamedSharding(global_mesh, P(None))
-
-    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp],
-                                          list(global_mesh.devices.flat))
-
-    self.assertLen(out_indices, len(in_avals))
-    self.assertTrue(all(len(out) == len(global_mesh.local_devices)
-                        for out in out_indices))
-    self.assertTrue(all(len(i) == aval.ndim
-                        for out, aval in safe_zip(out_indices, in_avals) for i in out))
-    self.assertTrue(all(i == (slice(None),) * aval.ndim
-                        for out, aval in safe_zip(out_indices, in_avals) for i in out))
-
   @parameterized.named_parameters(
       ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
       ("only_unspecified", UNSPECIFIED),

View File

@@ -3006,9 +3006,7 @@ class ShardArgsTest(jtu.JaxTestCase):
     x = np.arange(math.prod(shape)).reshape(shape)
     arg = make_arg(x)
     sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-    results = pxla.shard_args(
-        jax.devices()[:nshards], [indices], [sharding], [arg]
-    )
+    results = pxla.shard_args([sharding], [arg])
     self.assertEqual(len(results), 1)
     if isinstance(results[0], array.ArrayImpl):
       bufs = results[0]._arrays