Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
Remove indices and devices from shard_arg_handlers and shard_args.
This only affects the Python dispatch path. It has no impact on the speed of the C++ dispatch path (which is why the benchmarks are **not** regressing). If your code ends up taking the Python dispatch path, something is already going wrong anyway.

PiperOrigin-RevId: 596081987
commit b8098b1782 (parent ed62f28164)
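The interface change is easiest to see in the handler signatures. The sketch below is illustrative only (`_shard_my_array_*` is a hypothetical handler, not part of the diff); it mirrors what the real handlers such as `_shard_array` do before and after this change, using internals as they exist at this revision.

from jax._src import api_util
from jax._src.interpreters import pxla

# Hypothetical handler, before: the caller precomputes devices and per-device
# indices and threads them through shard_arg/shard_args.
def _shard_my_array_old(x, devices, indices, sharding):
  aval = api_util.shaped_abstractify(x)
  return pxla.batched_device_put(aval, sharding, [x[i] for i in indices], devices)

# Hypothetical handler, after: only the sharding is passed; devices and indices
# are derived from it inside the handler.
def _shard_my_array_new(x, sharding):
  devices = pxla.get_addressable_devices_for_shard_arg(sharding)
  indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
  aval = api_util.shaped_abstractify(x)
  return pxla.batched_device_put(aval, sharding, [x[i] for i in indices], devices)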
@@ -63,6 +63,7 @@ from jax._src.api_util import (
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
@@ -1845,7 +1846,8 @@ def _cpp_pmap(
     return out, fastpath_data

   cpp_mapped_f = pmap_lib.pmap(
-      fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
+      fun, cache_miss, static_broadcasted_tuple,
+      pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
       pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)

@@ -834,8 +834,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   if not candidates_list:
     # This array isn't sharded correctly. Reshard it via host roundtrip.
     # TODO(skye): more efficient reshard?
-    return pxla.shard_arg(x._value, devices, indices, sharding,
-                          canonicalize=False)
+    return pxla.shard_arg(x._value, sharding, canonicalize=False)
   # Try to find a candidate buffer already on the correct device,
   # otherwise copy one of them.
   for buf in candidates_list:
@@ -848,10 +847,11 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)


-def _array_shard_arg(x, devices, indices, sharding):
+def _array_shard_arg(x, sharding):
   x._check_if_deleted()

   x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
+  indices = sharding.addressable_devices_indices_map(x.shape).values()
   if not x.is_fully_addressable:
     if tuple(x_indices) == tuple(indices):
       return x
@@ -859,16 +859,15 @@ def _array_shard_arg(x, devices, indices, sharding):
       raise NotImplementedError(
           "Cannot reshard an input that is not fully addressable")
   else:
+    devices = pxla.get_addressable_devices_for_shard_arg(sharding)
     if tuple(x_indices) == tuple(indices):
-      return xc.copy_array_to_devices_with_sharding(
-          x, list(devices), sharding)
+      return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
     # Resharding starts here:
     if dispatch.is_single_device_sharding(x.sharding):
       return shard_device_array(x, devices, indices, sharding)
     else:
       return shard_sharded_device_array_slow_path(x, devices, indices, sharding)


 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg

@@ -124,12 +124,8 @@ class RuntimeTokenSet(threading.local):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
-    s = NamedSharding(pxla.Mesh(devices, axis_names=["dev"]),
-                      PartitionSpec([]))
-    indices = tuple(
-        s.addressable_devices_indices_map(tok.shape).values())
-    sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0]
+    s = jax.sharding.GSPMDSharding.get_replicated(devices)
+    sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
@@ -331,8 +327,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:

 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
   result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
-  map_ = s.devices_indices_map(aval.shape)  # type: ignore
-  return result_handler(pxla.shard_arg(x, list(map_), list(map_.values()), s))
+  return result_handler(pxla.shard_arg(x, s))

 def _override_get_device_assignment(sharding, *args, **kwargs):
   da = sharding._device_assignment
@@ -106,60 +106,46 @@ ShardingSpec = sharding_specs.ShardingSpec

 def identity(x): return x

-def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
-  """Returns a list of size len(devices) containing per-device buffers.
-
-  For the C++ pmap path, we fallback to Python (this function) to shard
-  arguments that are not supported by the C++ `ShardArg`.
-
-  Args:
-    arg: The Python argument.
-    devices: The list of devices to shard over.
-    arg_indices: A list of `len(devices)` indices to use to shard the argument.
-  """
+def shard_arg(arg, sharding, canonicalize=True):
   if canonicalize:
     arg = xla.canonicalize_dtype(arg)
-  return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
+  return shard_arg_handlers[type(arg)](arg, sharding)


 @profiler.annotate_function
 def shard_args(
-    devices: Sequence[xb.xla_client.Device],
-    indices: Sequence[Sequence[Index]],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    args,
+    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
 ) -> Sequence[jax.Array]:
-  """Shard each argument data array along its leading axis.
-
-  Args:
-    devices: sequence of Devices mapping replica index to a physical device.
-    indices: sequence of the same length as `args` describing how each arg
-      should be sharded/replicated across `devices`. Each element in `indices`
-      is the same length as `devices`.
-    args: a sequence of JaxTypes representing arguments to be sharded according
-      to `indices` and placed on `devices`.
-
-  Returns:
-    A list of length matching args, containing lists of per-device buffers
-    for each argument.
-  """
-  return [shard_arg(arg, devices, indices[i], shardings[i])
-          for i, arg in enumerate(args)]
+  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]

-shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
+shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}

-def _shard_token(x, devices, indices, sharding):
+@lru_cache(maxsize=1024)
+def get_addressable_devices_for_shard_arg(
+    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
+  return s._addressable_device_assignment
+
+@lru_cache(maxsize=1024)
+def _get_replicated_slices(num_addressable_devices: int):
+  return ((slice(None),),) * num_addressable_devices
+
+def _shard_token(x, sharding):
+  devices = get_addressable_devices_for_shard_arg(sharding)
+  indices = _get_replicated_slices(len(devices))
   zeros = np.zeros((), dtype=np.dtype(np.bool_))
   aval = api_util.shaped_abstractify(zeros)
-  return batched_device_put(aval, sharding, [zeros for i in indices], devices)
+  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
 shard_arg_handlers[core.Token] = _shard_token

-def _masked_array_error(x, devices, indices, sharding):
+def _masked_array_error(x, sharding):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

-def _shard_array(x, devices, indices, sharding):
+def _shard_array(x, sharding):
+  indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+  devices = get_addressable_devices_for_shard_arg(sharding)
   if x.dtype == dtypes.float0:
     x = np.zeros(x.shape, dtype=np.dtype(bool))
   aval = api_util.shaped_abstractify(x)
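For callers, `shard_args` now takes one sharding per argument and nothing else. A minimal usage sketch follows (internal API, so subject to change; the single-device setup here is illustrative and mirrors the `ShardArgsTest` change at the end of this diff):

import numpy as np
import jax
from jax._src.interpreters import pxla

x = np.arange(8).reshape(4, 2)
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])

# One sharding per argument; devices and indices are no longer passed.
[out] = pxla.shard_args([s], [x])
print(out.sharding, out.shape)  # a committed jax.Array placed on jax.devices()[0]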
@@ -167,8 +153,8 @@ def _shard_array(x, devices, indices, sharding):
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array

-def _shard_darray(x, devices, indices, sharding):
-  return shard_arg(x._data, devices, indices, sharding)
+def _shard_darray(x, sharding):
+  return shard_arg(x._data, sharding)
 shard_arg_handlers[core.DArray] = _shard_darray

 def batched_device_put(aval: core.ShapedArray,
@@ -183,7 +169,7 @@ def batched_device_put(aval: core.ShapedArray,
   if len(bufs) == len(xs):
     return array.ArrayImpl(
         aval, sharding, bufs, committed=committed, _skip_checks=True)
-  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
+  return xc.batched_device_put(aval, sharding, xs, list(devices), committed)  # type: ignore

 def shard_aval(size, axis: int, aval):
   try:
@@ -849,8 +835,8 @@ class UnloadedPmapExecutable:
                       if spec.sharding_spec is not None else None)
     handle_outs = local_avals_to_results_handler(self.local_output_avals,
                                                  self.output_shardings)
-    handle_args = InputsHandler(self.compiled.local_devices(),
-                                self.input_shardings, input_indices)
+    handle_args = InputsHandler(self.input_shardings,
+                                self.compiled.local_devices(), input_indices)
     execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
                                     self.backend, handle_args, handle_outs,
                                     self.unordered_effects,
@@ -1054,9 +1040,8 @@ def _get_pmap_sharding(devices, specs):
 class InputsHandler:
   __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

-  def __init__(self, local_devices, in_shardings, input_indices):
-    self.handler = partial(
-        shard_args, local_devices, input_indices, in_shardings)
+  def __init__(self, in_shardings, local_devices=None, input_indices=None):
+    self.handler = partial(shard_args, in_shardings)
     self.local_devices = local_devices
     self.in_shardings = in_shardings
     self.input_indices = input_indices
@@ -2248,37 +2233,36 @@ class MeshComputation(stages.XlaLowering):
     return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())


-@lru_cache(maxsize=1024)
-def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
-  if ndim is None:
-    return ((slice(None),),) * num_addressable_devices
-  else:
-    return ((slice(None),) * ndim,) * num_addressable_devices
-
-
-def _get_input_indices(
-    avals: Sequence[ShapedArray],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
-) -> Sequence[tuple[Index | None, ...]]:
-
-  input_indices = []
-  if not isinstance(da_object, _DeviceAssignment):
-    da_object = _create_da_object(tuple(da_object))
-  num_addressable_devices = len(da_object.addressable_device_list)
-
-  for aval, sharding in zip(avals, shardings):
-    if aval is core.abstract_token:
-      index = _get_replicated_slices(num_addressable_devices, None)
-    else:
-      if sharding.is_fully_replicated:
-        index = _get_replicated_slices(num_addressable_devices, aval.ndim)
-      else:
-        index = tuple(
-            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
-    input_indices.append(index)
-
-  return input_indices
+if xla_extension_version < 229:
+  def _get_input_indices(
+      avals: Sequence[ShapedArray],
+      shardings: Sequence[sharding_impls.XLACompatibleSharding],
+      da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
+  ) -> Sequence[tuple[Index | None, ...]]:
+
+    input_indices = []
+    if not isinstance(da_object, _DeviceAssignment):
+      da_object = _create_da_object(tuple(da_object))
+    num_addressable_devices = len(da_object.addressable_device_list)
+
+    def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
+      if ndim is None:
+        return ((slice(None),),) * num_addressable_devices
+      else:
+        return ((slice(None),) * ndim,) * num_addressable_devices
+
+    for aval, sharding in zip(avals, shardings):
+      if aval is core.abstract_token:
+        index = _get_replicated_slices(num_addressable_devices, None)
+      else:
+        if sharding.is_fully_replicated:
+          index = _get_replicated_slices(num_addressable_devices, aval.ndim)
+        else:
+          index = tuple(
+              sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
+      input_indices.append(index)
+
+    return input_indices


 def get_gspmd_shardings_from_executable(
@@ -2604,10 +2588,13 @@ class UnloadedMeshExecutable:
   all_args_info: AllArgsInfo | None

   def build_unsafe_call(self):
-    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
-                                       self.device_assignment)
-    handle_args = InputsHandler(self.xla_executable.local_devices(),
-                                self.input_shardings, input_indices)
+    if xla_extension_version >= 229:
+      handle_args = InputsHandler(self.input_shardings)
+    else:
+      input_indices = _get_input_indices(self.input_avals, self.input_shardings,
+                                         self.device_assignment)
+      handle_args = InputsHandler(
+          self.input_shardings, self.xla_executable.local_devices(), input_indices)
     handle_outs = global_avals_to_results_handler(
         self.output_avals, self.output_shardings, self.committed,
         self.are_out_shardings_from_xla)  # type: ignore  # arg-type
@@ -2755,6 +2742,7 @@ class MeshExecutableFastpathData(NamedTuple):
   out_avals: Sequence[ShapedArray]
   out_committed: Sequence[bool]
   kept_var_bitvec: Iterable[bool]
+  # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
   arg_handler_devices: Sequence[xc.Device]
   arg_handler_indices: Sequence[tuple[Index | None, ...]]

@@ -2865,13 +2853,20 @@ class MeshExecutable(stages.XlaExecutable):
       return outs, fastpath_data

     if xla_extension_version >= 226:
-      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                          tree_util.dispatch_registry, shard_arg)
+      return xc._xla.pjit(
+          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+          tree_util.dispatch_registry,
+          shard_arg if xla_extension_version >= 229 else temp_shard_arg)  # type: ignore
     else:
       return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],  # type: ignore
                           tree_util.dispatch_registry)


+# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
+def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
+  return shard_arg(arg, sharding)
+
+
 def check_arg_avals_for_call(ref_avals, arg_avals,
                              jaxpr_debug_info: core.JaxprDebugInfo | None = None):
   if len(ref_avals) != len(arg_avals):
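`temp_shard_arg` is a temporary shim: jaxlib builds with `xla_extension_version < 229` still invoke the registered shard-arg callable with the old `(arg, devices, indices, sharding)` convention, so the wrapper drops the now-redundant arguments and forwards to the two-argument `shard_arg`. A standalone sketch of the same compatibility pattern (the names here are illustrative, not from the diff):

from typing import Any, Sequence

MIN_NEW_VERSION = 229  # version at which the caller switches to the new convention

def shard_v2(arg: Any, sharding: Any) -> tuple[Any, Any]:
  # New convention: everything needed is derived from the sharding.
  return (arg, sharding)

def shard_v1_adapter(arg: Any, devices: Sequence[Any], indices: Sequence[Any],
                     sharding: Any, canonicalize: bool = True) -> tuple[Any, Any]:
  # Old convention: devices/indices still arrive but are redundant; drop them.
  del devices, indices, canonicalize
  return shard_v2(arg, sharding)

def pick_shard_callback(extension_version: int):
  # Mirrors: shard_arg if xla_extension_version >= 229 else temp_shard_arg
  return shard_v2 if extension_version >= MIN_NEW_VERSION else shard_v1_adapter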
@@ -2926,20 +2921,15 @@ def _compile_replicated_mesh_executable_from_hlo(
   in_shardings = semantics_in_shardings.shardings
   out_shardings = semantics_out_shardings.shardings

-  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
-  if pmap_nreps > 1:
-    # For a jit wrapping a pmap, replicate each input index to match the
-    # devices of the replicated jit computation.
-    input_indices = [index * pmap_nreps for index in input_indices]
   kept_var_idx = set(kept_var_idx)
   # Will compute out_handler with executable information.
   unsafe_call = backend.compile_replicated(
       is_trivial=False, name=name, computation=computation,
       compile_options=compile_options, host_callbacks=host_callbacks,
       has_unordered_effects=has_unordered_effects,
-      ordered_effects=ordered_effects, in_avals=global_in_avals,
-      in_indices=input_indices, in_shardings=in_shardings,
-      kept_var_idx=kept_var_idx,
+      device_assignment=da, ordered_effects=ordered_effects,
+      in_avals=global_in_avals,
+      in_shardings=in_shardings, kept_var_idx=kept_var_idx,
       out_avals=global_out_avals, out_shardings=out_shardings,
       committed=committed, pmap_nreps=pmap_nreps)
   xla_executable = None
@@ -234,7 +234,8 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
         getattr(fun, "__name__", "<unnamed function>"),
         fun, cache_miss, static_argnums, static_argnames,
         donate_argnums, tree_util.dispatch_registry,
-        pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
   else:
     cpp_pjit_f = xc._xla.pjit(  # type: ignore
         getattr(fun, "__name__", "<unnamed function>"),
@@ -1348,9 +1349,11 @@ def _pjit_call_impl(*args, jaxpr,
   has_explicit_sharding = _pjit_explicit_sharding(
       in_shardings, out_shardings, None, None)
   if xla_extension_version >= 226:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                        tree_util.dispatch_registry, pxla.shard_arg,
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
+    return xc._xla.pjit(
+        name, f, call_impl_cache_miss, [], [], donated_argnums,
+        tree_util.dispatch_registry,
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(has_explicit_sharding))(*args)
   else:
     return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
                         tree_util.dispatch_registry,
@@ -636,20 +636,11 @@ xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x


-def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
-  aval = x.aval
-  key_shape = aval.dtype._impl.key_shape
+def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
   arr = x._base_array
-
-  # TODO(yashkatariya,frostig): This assumes that the last dimensions are not
-  # sharded. This is only true when enable_custom_prng is True.
-  trailing_inds = [slice(None)] * len(key_shape)
-  phys_indices = [(*inds, *trailing_inds) for inds in indices]
   phys_sharding = make_key_array_phys_sharding(
-      aval, sharding, is_sharding_from_xla=False)
-  return pxla.shard_arg_handlers[type(arr)](
-      arr, devices, phys_indices, phys_sharding
-  )
+      x.aval, sharding, is_sharding_from_xla=False)
+  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)


 pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
@@ -30,9 +30,13 @@ XLADeviceAssignment = Sequence[Device]
 @functools.lru_cache(maxsize=4096)
 def _addressable_devices_indices_map(
     sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
+  global_map = sharding.devices_indices_map(global_shape)
   if sharding.is_fully_addressable:
-    return sharding.devices_indices_map(global_shape)
-  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
+    return global_map
+  if hasattr(sharding, '_internal_device_list'):
+    return {d: global_map[d]
+            for d in sharding._internal_device_list.addressable_device_list}
+  return {d: ind for d, ind in global_map.items()
           if d.process_index == d.client.process_index()}

@@ -110,6 +110,10 @@ class XLACompatibleSharding(sharding.Sharding):

   @functools.cached_property
   def _addressable_device_assignment(self) -> XLADeviceAssignment:
+    if self.is_fully_addressable:
+      return self._device_assignment
+    if hasattr(self, '_internal_device_list'):
+      return tuple(self._internal_device_list.addressable_device_list)
     return tuple(d for d in self._device_assignment
                  if d.process_index == d.client.process_index())

@@ -3042,8 +3042,8 @@ class FooArray:
   size = property(lambda self: self.data.size // 2)
   ndim = property(lambda self: self.data.ndim - 1)

-def shard_foo_array_handler(x, devices, indices, sharding):
-  device, = devices
+def shard_foo_array_handler(x, sharding):
+  device, = sharding._addressable_device_assignment
   aval = core.raise_to_shaped(core.get_aval(x.data))
   return pxla.batched_device_put(
       aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
@@ -45,7 +45,7 @@ from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import Sharding, _addressable_devices_indices_map
+from jax._src.sharding import Sharding
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
@@ -60,7 +60,7 @@ from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
-from jax._src.util import curry, unzip2, safe_zip
+from jax._src.util import curry, unzip2

 config.parse_flags_with_absl()

@@ -3546,32 +3546,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
     self.assertEqual(out4.devices(), {jax.devices()[1]})

-  def test_get_indices_cache(self):
-    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    ns = NamedSharding(mesh, P('x'))
-    ns2 = NamedSharding(mesh, P('x', 'y'))
-
-    np_inp = np.arange(16).reshape(8, 2)
-    arr1 = jax.device_put(np_inp, ns)
-    arr2 = jax.device_put(np_inp, ns2)
-    arr3 = jax.device_put(np_inp, ns)
-
-    _addressable_devices_indices_map.cache_clear()
-
-    cache_info1 = _addressable_devices_indices_map.cache_info()
-    out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3)
-    cache_info2 = _addressable_devices_indices_map.cache_info()
-    self.assertArraysEqual(out, np_inp * 3)
-
-    # arr3 and arr1 should have the same GSPMDSharding objects internally.
-    # So there will be 2 hits in _addressable_devices_indices_map,
-    # One in `pxla._get_input_indices` and second in `_array_shard_arg`.
-    self.assertEqual(cache_info2.hits, cache_info1.hits + 2)
-    # There will double the amount of misses as hits because arr1 and arr2's
-    # sharding are not the same. So 2 misses in _addressable_devices_indices_map
-    # and 2 in _array_shard_arg.
-    self.assertEqual(cache_info2.misses, cache_info1.misses + 4)
-
   def test_same_named_sharding_pspec_on_eager_ops(self):
     mesh = jtu.create_global_mesh((1, 8, 1), ('x', 'y', 'z'))
     sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z'))
@@ -4261,26 +4235,6 @@ class UtilTest(jtu.JaxTestCase):
         sharding_impls.array_mapping_to_axis_resources(inp), expected_out
     )

-  def test_get_input_indices_fully_replicated(self):
-    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
-    global_in_aval1 = core.ShapedArray((4, 4), jnp.int32)
-    global_in_aval2 = core.ShapedArray((4, 4, 4), jnp.int32)
-    global_in_aval3 = core.ShapedArray((), jnp.int32)
-    in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
-
-    mp = NamedSharding(global_mesh, P(None))
-
-    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp],
-                                          list(global_mesh.devices.flat))
-
-    self.assertLen(out_indices, len(in_avals))
-    self.assertTrue(all(len(out) == len(global_mesh.local_devices)
-                        for out in out_indices))
-    self.assertTrue(all(len(i) == aval.ndim
-                        for out, aval in safe_zip(out_indices, in_avals) for i in out))
-    self.assertTrue(all(i == (slice(None),) * aval.ndim
-                        for out, aval in safe_zip(out_indices, in_avals) for i in out))
-
   @parameterized.named_parameters(
       ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
       ("only_unspecified", UNSPECIFIED),
@@ -3006,9 +3006,7 @@ class ShardArgsTest(jtu.JaxTestCase):
     x = np.arange(math.prod(shape)).reshape(shape)
     arg = make_arg(x)
     sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-    results = pxla.shard_args(
-        jax.devices()[:nshards], [indices], [sharding], [arg]
-    )
+    results = pxla.shard_args([sharding], [arg])
    self.assertEqual(len(results), 1)
    if isinstance(results[0], array.ArrayImpl):
      bufs = results[0]._arrays