From fc6df3218ca8f0e43a3ab447d22adccec7407d78 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 19 Jan 2024 03:53:01 -0800
Subject: [PATCH] Add a new experimental option jax_pmap_no_rank_reduction.

This option changes the implementation of pmap so that the individual
shards have the same rank as the entire array, i.e., in the terminology
of pmap, it uses a "chunked" axis instead of an "unstacked" axis.

Previously, a typical array used by pmap might have a shape of, say,
[8, 100], if sharded across 8 accelerators on its first axis, and each
individual shard would then have a shape of, say, [100]. With this
change, each individual shard has a shape of [1, 100] instead.

Why do this? The main reason is that XLA's sharding (HloSharding), which
is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding,
cannot represent a change of rank. This means that the kind of sharding
used by pmap cannot be represented to XLA as a sharding. If we change
the definition of PmapSharding to preserve the array rank instead, then
PmapSharding can in the future be represented directly as a kind of
sharding known to XLA.

The new definition of PmapSharding will allow a number of internal
simplifications to JAX; for example, in a subsequent change we can
probably delete PmapSharding entirely. This in turn would also allow us
to delete the APIs `jax.device_put_replicated` and
`jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change in which we would like
to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to
delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards` or
   `jax.make_array_from_single_device_arrays`, since the shapes of the
   shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user-visible
impact from this change.
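For illustration, a minimal sketch of the user-visible difference
(not part of this patch; it assumes a host with several local devices
and uses only public JAX APIs):

    import jax
    import jax.numpy as jnp

    # Opt in to the new experimental behavior (it is off by default).
    jax.config.update("jax_pmap_no_rank_reduction", True)

    n = jax.local_device_count()
    x = jnp.zeros((n, 100))
    y = jax.pmap(lambda a: a + 1)(x)

    for shard in y.addressable_shards:
      # With the option off, each shard has the reduced-rank shape (100,).
      # With the option on, each shard keeps the array's rank: (1, 100).
      print(shard.data.shape)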
PiperOrigin-RevId: 599787818
---
 jax/BUILD                     |  1 +
 jax/_src/api.py               | 20 +++++++++-
 jax/_src/array.py             | 32 +++++++++++----
 jax/_src/config.py            |  8 ++++
 jax/_src/interpreters/pxla.py | 74 +++++++++++++++++++++++++++++------
 jax/_src/lax/lax.py           |  2 +-
 jax/_src/sharding_impls.py    | 20 +++++++++-
 jax/_src/sharding_specs.py    | 15 +++++--
 jax/_src/util.py              |  4 ++
 tests/pmap_test.py            |  7 +++-
 10 files changed, 154 insertions(+), 29 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index 8f913c1d2..d7f163070 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -726,6 +726,7 @@ pytype_strict_library(
     name = "sharding_specs",
     srcs = ["_src/sharding_specs.py"],
     deps = [
+        ":config",
         ":op_shardings",
         ":util",
         "//jax/_src/lib",
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 418fe3dfd..edfe052c0 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -46,6 +46,7 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import effects
 from jax._src import array
+from jax._src import basearray
 from jax._src import dtypes
 from jax._src import sharding_impls
 from jax._src import sharding_specs
@@ -1641,6 +1642,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
              for pi in range(xb.process_count(backend)))
   return global_axis_size
 
+
 def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
                   donate_tuple, in_devices, backend_name,
                   axis_size, args, kwargs):
@@ -2603,7 +2605,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
     sharding = PmapSharding(np.array(devices), sharding_spec)
     if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
       return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
-    return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices))
+    if config.pmap_no_rank_reduction.value:
+      ys = []
+      for x in xs:
+        if not isinstance(x, (np.ndarray, basearray.Array)):
+          x = np.asarray(x)
+        ys.append(x[None])
+    else:
+      ys = xs
+    return pxla.batched_device_put(stacked_aval, sharding, ys, list(devices))


   with config.explicit_device_put_scope():
@@ -2649,7 +2659,13 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
                               core.raise_to_shaped(core.get_aval(x)))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
-    buf = device_put(x, devices[0])
+    if config.pmap_no_rank_reduction.value:
+      if isinstance(x, (np.ndarray, basearray.Array)):
+        buf = device_put(x[None], devices[0])
+      else:
+        buf = device_put(x, devices[0])[None]
+    else:
+      buf = device_put(x, devices[0])
     sharding = PmapSharding(np.array(devices), sharding_spec)
     if dtypes.issubdtype(aval.dtype, dtypes.extended):
       return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
diff --git a/jax/_src/array.py b/jax/_src/array.py
index fb82ffe92..09c722a8c 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -303,6 +303,7 @@ class ArrayImpl(basearray.Array):
     return format(self._value, format_spec)
 
   def __getitem__(self, idx):
+    from jax._src.lax import lax
     from jax._src.numpy import lax_numpy
     self._check_if_deleted()
 
@@ -314,23 +315,38 @@ class ArrayImpl(basearray.Array):
           f"was indexed with {num_idx} non-None/Ellipsis indices.")
 
     if isinstance(self.sharding, PmapSharding):
-      if not isinstance(idx, tuple):
-        cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
+      if config.pmap_no_rank_reduction.value:
+        cidx = idx if isinstance(idx, tuple) else (idx,)
+
+        padded_cidx = tuple(
+            slice(i, i + 1, None) if isinstance(i, int) else i for i in cidx
+        ) + (slice(None),) * (len(self.shape) - len(cidx))
       else:
-        cidx = idx + (slice(None),) * (len(self.shape) - len(idx))
+        if not isinstance(idx, tuple):
+          padded_cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
+        else:
+          padded_cidx = idx + (slice(None),) * (len(self.shape) - len(idx))
+
       indices = tuple(self.sharding.devices_indices_map(self.shape).values())
       try:
-        arr_idx = indices.index(cidx)
+        arr_idx = indices.index(padded_cidx)
       except ValueError:
         arr_idx = None
       if arr_idx is not None:
         a = self._arrays[arr_idx]
-        return ArrayImpl(
+        out = ArrayImpl(
             a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
             _skip_checks=True)
-      return lax_numpy._rewriting_take(self, idx)
-    else:
-      return lax_numpy._rewriting_take(self, idx)
+
+        if config.pmap_no_rank_reduction.value:
+          # If cidx was the index of a single shard, then it corresponds to one
+          # shard of the chunked dimension.
+          dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int))
+          return lax.squeeze(out, dimensions=dims)
+        else:
+          return out
+
+    return lax_numpy._rewriting_take(self, idx)
 
   def __iter__(self):
     if self.ndim == 0:
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 38172f214..455111195 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1422,3 +1422,11 @@ define_string_state(
           '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
           'for.'),
     update_global_hook=_update_debug_log_modules)
+
+pmap_no_rank_reduction = define_bool_state(
+    name='jax_pmap_no_rank_reduction',
+    default=False,
+    help=(
+        "If True, pmap shards have the same rank as their enclosing array."
+    )
+)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 321d5a4c8..ebfa5bf7c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -72,7 +72,8 @@ from jax._src.sharding_impls import (
     is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
-                           wrap_name, tuple_delete, distributed_debug_log,
+                           wrap_name, tuple_update, tuple_delete,
+                           distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)


@@ -171,12 +172,13 @@ def batched_device_put(aval: core.ShapedArray,
         aval, sharding, bufs, committed=committed, _skip_checks=True)
   return xc.batched_device_put(aval, sharding, xs, list(devices), committed)  # type: ignore
 
-def shard_aval(size, axis: int, aval):
+def _shard_aval(size, axis: int, aval):
   try:
-    return shard_aval_handlers[type(aval)](size, axis, aval)
+    return _shard_aval_handlers[type(aval)](size, axis, aval)
   except KeyError as err:
-    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
-shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
+    raise TypeError(f"No _shard_aval handler for type: {type(aval)}") from err
+_shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
+
 def _shard_abstract_array(size, axis: int, x):
   try:
     if x.shape[axis] != size:
@@ -184,8 +186,11 @@
                        f"shape {x.shape}")
   except IndexError:
     raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
-  return x.update(shape=tuple_delete(x.shape, axis))
-shard_aval_handlers[ShapedArray] = _shard_abstract_array
+  if config.pmap_no_rank_reduction.value:
+    return x.update(shape=tuple_update(x.shape, axis, 1))
+  else:
+    return x.update(shape=tuple_delete(x.shape, axis))
+_shard_aval_handlers[ShapedArray] = _shard_abstract_array


 def local_aval_to_result_handler(
@@ -620,21 +625,39 @@ def find_replicas(
   num_global_replicas = global_axis_size * jaxpr_replicas
   return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
 
+@lu.transformation
+def _change_argument_ranks(in_axes, out_axes_thunk, *args):
+  args = tuple(
+      arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,))
+      for in_axis, arg in zip(in_axes, args)
+  )
+  results = yield (args, {})
+  out_axes = out_axes_thunk()
+  yield tuple(
+      x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,))
+      for x, axis in zip(results, out_axes)
+  )
+
 
 def stage_parallel_callable(
     pci: ParallelCallableInfo, fun: lu.WrappedFun
 ) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
   sharded_avals = tuple(
-      shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
+      _shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
       for axis, aval in safe_zip(pci.in_axes, pci.avals))
 
+  orig_fun = fun
+  if config.pmap_no_rank_reduction.value:
+    fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk)
+  else:
+    fun = orig_fun
   with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
     with dispatch.log_elapsed_time(
         "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
         fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
       jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
           fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
-  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
+  jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
 
   assert len(out_sharded_avals) == len(pci.out_axes), (
@@ -666,7 +689,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_parameters: mlir.LoweringParameters):
+    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -782,6 +805,35 @@ def lower_parallel_callable(
                          jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)


+def _pmap_unmap_shaped_array(
+    size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
+  ) -> ShapedArray:
+  named_shape = dict(aval.named_shape)
+  named_shape.pop(axis_name, None)  # TODO: make this mandatory
+  if axis is None: return aval.update(named_shape=named_shape)
+  elif type(axis) is int:
+    return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
+                       named_shape=named_shape, weak_type=aval.weak_type)
+  else: raise TypeError(axis)
+
+
+AvalMapHandlerPair = tuple[Any, Callable]
+_pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
+    ShapedArray: (Any, _pmap_unmap_shaped_array),
+}
+
+def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
+                        aval: core.AbstractValue) -> core.AbstractValue:
+  if not config.pmap_no_rank_reduction.value:
+    return core.unmapped_aval(size, axis_name, axis, aval)
+
+  _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
+  if handler is not None:
+    return handler(size, axis_name, axis, aval)
+  else:
+    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
+
+
 class PmapComputation(stages.XlaLowering):
   _hlo: ir.Module
   _executable: PmapExecutable | None
@@ -938,7 +990,7 @@ class UnloadedPmapExecutable:
 
     local_unmapped_avals = [
         _cast_to_shaped_array(
-            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
+            _pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
         if out_axis is not None else aval
         for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
     out_specs = [
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 682b7599b..1bd51e5ae 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -4547,7 +4547,7 @@ def _array_copy(arr: ArrayLike) -> Array:
 def _which_dim_sharded(s: PmapSharding) -> int | None:
   sharded_dim = None
   for i, s in enumerate(s.sharding_spec.sharding):
-    if isinstance(s, pxla.Unstacked):
+    if isinstance(s, (pxla.Unstacked, pxla.Chunked)):
       sharded_dim = i
       break
   return sharded_dim
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 5e87f4ca3..d648a0d2b 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -541,7 +541,16 @@ class PmapSharding(XLACompatibleSharding):
     num_ways_sharded = None
     for s in sharding_spec.sharding:
       if isinstance(s, sharding_specs.Unstacked):
+        assert num_ways_sharded is None
         num_ways_sharded = s.size
+      elif isinstance(s, sharding_specs.Chunked):
+        assert num_ways_sharded is None
+        if len(s.chunks) == 1:
+          num_ways_sharded = s.chunks[0]
+        else:
+          raise NotImplementedError(
+              'Multiple chunks in Chunked dimension not supported.')
+
     if num_ways_sharded is None:
       raise NotImplementedError(
           '`None` to sharded_dim is not supported. Please file a jax '
@@ -581,7 +590,7 @@ class PmapSharding(XLACompatibleSharding):
   @functools.cached_property
   def is_fully_replicated(self) -> bool:
     for s in self.sharding_spec.sharding:
-      if isinstance(s, sharding_specs.Unstacked):
+      if isinstance(s, (sharding_specs.Unstacked, sharding_specs.Chunked)):
         return False
     return True
 
@@ -596,6 +605,13 @@ class PmapSharding(XLACompatibleSharding):
       if isinstance(s, sharding_specs.Unstacked):
         sharded_dim = i
         sharded_dim_size = s.size
+        sharded_shape = util.tuple_delete(global_shape, sharded_dim)
+        break
+      elif isinstance(s, sharding_specs.Chunked):
+        sharded_dim = i
+        assert len(s.chunks) == 1, s.chunks
+        sharded_dim_size = s.chunks[0]
+        sharded_shape = util.tuple_update(global_shape, sharded_dim, 1)
         break
     if sharded_dim is None:
       return global_shape
@@ -605,7 +621,7 @@
           f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
           f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
           f'the number of devices={len(self._device_assignment)}')
-    return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
+    return sharded_shape


 def _op_sharding_to_pos_sharding(
diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py
index 46439c610..02d6276a5 100644
--- a/jax/_src/sharding_specs.py
+++ b/jax/_src/sharding_specs.py
@@ -36,6 +36,7 @@ from typing import Union
 
 import numpy as np
 
+from jax._src import config
 from jax._src import util
 from jax._src.lib import pmap_lib
 
@@ -188,9 +189,14 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
       return a
     # replication_factor represents the product of inner pmaps, so it goes
     # after the outer pmapped axis at index 0
+    if config.pmap_no_rank_reduction.value:
+      sharding = util.tuple_update(
+          pspec.sharding, map_axis, Chunked([axis_size]))
+    else:
+      sharding = util.tuple_insert(
+          pspec.sharding, map_axis, Unstacked(axis_size))
     return ShardingSpec(
-        sharding=util.tuple_insert(
-            pspec.sharding, map_axis, Unstacked(axis_size)),
+        sharding=sharding,
         mesh_mapping=itertools.chain(
             [ShardedAxis(sharded_in_axis)], maybe_replicate,
             map(shift_sharded_axis, pspec.mesh_mapping)))
@@ -203,7 +209,10 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
 def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0,
                               sharded_dim_size: int | None = None):
   if sharded_dim is not None:
-    sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
+    if config.pmap_no_rank_reduction.value:
+      sharded_shape = util.tuple_update(shape, sharded_dim, 1)
+    else:
+      sharded_shape = util.tuple_delete(shape, sharded_dim)
     if sharded_dim_size is None:
       sharded_dim_size = shape[sharded_dim]
   else:
diff --git a/jax/_src/util.py b/jax/_src/util.py
index f492da141..1e3ccc38f 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -414,6 +414,10 @@ def tuple_delete(t, idx):
   assert 0 <= idx < len(t), (idx, len(t))
   return t[:idx] + t[idx + 1:]
 
+def tuple_update(t, idx, val):
+  assert 0 <= idx < len(t), (idx, len(t))
+  return t[:idx] + (val,) + t[idx+1:]
+
 class HashableFunction:
   """Decouples function equality and hash from its identity.
 
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 4b7f9a96b..a1153cefe 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -3073,7 +3073,10 @@ class ArrayPmapTest(jtu.JaxTestCase):
     self.assertEqual(out2.shape, (dc, dc, 2))
     for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):
       self.assertArraysEqual(s1.data, input_data[i])
-      self.assertArraysEqual(s2.data, input_data)
+      if config.pmap_no_rank_reduction.value:
+        self.assertArraysEqual(s2.data, input_data[None])
+      else:
+        self.assertArraysEqual(s2.data, input_data)
 
   def test_pmap_array_sharding_mismatch(self):
     input_shape = (jax.device_count(), 2)
@@ -3105,7 +3108,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
 
     def amap(f, xs):
       ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
-      return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
+      return jax.device_put_sharded(ys, jax.local_devices()[:2])
 
     # leading axis is batch dim (i.e. mapped/parallel dim), of size 2
     x = jnp.array([[1., 0., 0.,