Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e., in the terminology of pmap, using a "chunked" axis instead of an "unstacked" axis.

For example, previously a typical array used by pmap might have a shape of, say, [8, 100] if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this? The main reason is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX; for example, in a subsequent change we can probably delete PmapSharding entirely. This in turn would also allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards` or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change, or
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user-visible impact.

PiperOrigin-RevId: 599787818
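A hedged sketch of the user-visible difference described above (the flag name and shard shapes come from this commit's description, not from running the code; assumes a jax build that includes the flag):

    # Inspecting pmap output shards before/after the new behavior.
    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    y = jax.pmap(lambda s: s + 1)(jnp.zeros((n, 100)))

    # With jax_pmap_no_rank_reduction disabled (the default), each shard drops
    # the mapped axis; with it enabled, each shard keeps a size-1 leading axis.
    print(y.addressable_shards[0].data.shape)  # (100,) by default, (1, 100) with the flag on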
This commit is contained in:
Parent: f04f305489
Commit: fc6df3218c
@@ -726,6 +726,7 @@ pytype_strict_library(
    name = "sharding_specs",
    srcs = ["_src/sharding_specs.py"],
    deps = [
        ":config",
        ":op_shardings",
        ":util",
        "//jax/_src/lib",
@@ -46,6 +46,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import basearray
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import sharding_specs
@@ -1641,6 +1642,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
                          for pi in range(xb.process_count(backend)))
  return global_axis_size


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
                  donate_tuple, in_devices, backend_name,
                  axis_size, args, kwargs):
@@ -2603,7 +2605,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
    sharding = PmapSharding(np.array(devices), sharding_spec)
    if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
      return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
    return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices))
    if config.pmap_no_rank_reduction.value:
      ys = []
      for x in xs:
        if not isinstance(x, (np.ndarray, basearray.Array)):
          x = np.asarray(x)
        ys.append(x[None])
    else:
      ys = xs
    return pxla.batched_device_put(stacked_aval, sharding, ys, list(devices))


  with config.explicit_device_put_scope():
@@ -2649,7 +2659,13 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
        core.raise_to_shaped(core.get_aval(x)))
    assert isinstance(aval, ShapedArray)
    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
    buf = device_put(x, devices[0])
    if config.pmap_no_rank_reduction.value:
      if isinstance(x, (np.ndarray, basearray.Array)):
        buf = device_put(x[None], devices[0])
      else:
        buf = device_put(x, devices[0])[None]
    else:
      buf = device_put(x, devices[0])
    sharding = PmapSharding(np.array(devices), sharding_spec)
    if dtypes.issubdtype(aval.dtype, dtypes.extended):
      return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
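A minimal sketch of what the branch above means for jax.device_put_replicated users (the per-shard shapes follow the diff, not measured output; assumes at least one local device):

    # Replicate a host array over all local devices; the result's leading axis
    # is the device axis, and only the per-device buffer shape changes with the
    # new flag.
    import jax
    import numpy as np

    x = np.arange(4.0)                                # shape (4,)
    r = jax.device_put_replicated(x, jax.local_devices())
    assert r.shape == (len(jax.local_devices()), 4)
    # Per the branch above, each device buffer is device_put(x) by default but
    # device_put(x[None]) (shape (1, 4)) when jax_pmap_no_rank_reduction is on,
    # which is also why zero-copy expectations may change for this API.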
@@ -303,6 +303,7 @@ class ArrayImpl(basearray.Array):
    return format(self._value, format_spec)

  def __getitem__(self, idx):
    from jax._src.lax import lax
    from jax._src.numpy import lax_numpy
    self._check_if_deleted()

@@ -314,23 +315,38 @@ class ArrayImpl(basearray.Array):
          f"was indexed with {num_idx} non-None/Ellipsis indices.")

    if isinstance(self.sharding, PmapSharding):
      if not isinstance(idx, tuple):
        cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
      if config.pmap_no_rank_reduction.value:
        cidx = idx if isinstance(idx, tuple) else (idx,)

        padded_cidx = tuple(
            slice(i, i + 1, None) if isinstance(i, int) else i for i in cidx
        ) + (slice(None),) * (len(self.shape) - len(cidx))
      else:
        cidx = idx + (slice(None),) * (len(self.shape) - len(idx))
        if not isinstance(idx, tuple):
          padded_cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
        else:
          padded_cidx = idx + (slice(None),) * (len(self.shape) - len(idx))

      indices = tuple(self.sharding.devices_indices_map(self.shape).values())
      try:
        arr_idx = indices.index(cidx)
        arr_idx = indices.index(padded_cidx)
      except ValueError:
        arr_idx = None
      if arr_idx is not None:
        a = self._arrays[arr_idx]
        return ArrayImpl(
        out = ArrayImpl(
            a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
            _skip_checks=True)
      return lax_numpy._rewriting_take(self, idx)
    else:
      return lax_numpy._rewriting_take(self, idx)

        if config.pmap_no_rank_reduction.value:
          # If cidx was the index of a single shard, then it corresponds to one
          # shard of the chunked dimension.
          dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int))
          return lax.squeeze(out, dimensions=dims)
        else:
          return out

    return lax_numpy._rewriting_take(self, idx)

  def __iter__(self):
    if self.ndim == 0:
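A standalone NumPy illustration of the padded-index logic above (the helper name is made up for this sketch; only the shapes are meant to match the diff):

    import numpy as np

    def pad_index(idx, ndim):
      # Mirror the diff: wrap a non-tuple index, turn int entries into length-1
      # slices, and pad with full slices up to the array's rank.
      cidx = idx if isinstance(idx, tuple) else (idx,)
      padded = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in cidx)
      return cidx, padded + (slice(None),) * (ndim - len(cidx))

    x = np.arange(800).reshape(8, 100)
    cidx, padded = pad_index(0, x.ndim)   # padded == (slice(0, 1), slice(None))
    shard = x[padded]                     # shape (1, 100): the shard keeps its rank
    dims = tuple(i for i, e in enumerate(cidx) if isinstance(e, int))
    assert np.squeeze(shard, axis=dims).shape == x[0].shape  # squeeze recovers (100,)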
@@ -1422,3 +1422,11 @@ define_string_state(
          '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
          'for.'),
    update_global_hook=_update_debug_log_modules)

pmap_no_rank_reduction = define_bool_state(
    name='jax_pmap_no_rank_reduction',
    default=False,
    help=(
        "If True, pmap shards have the same rank as their enclosing array."
    )
)
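For reference, boolean config states defined with define_bool_state are toggled like other JAX flags; a minimal sketch (assuming a jax build that ships this flag):

    import jax

    # Opt in to the experimental rank-preserving pmap shards for this process.
    # The corresponding environment variable would be JAX_PMAP_NO_RANK_REDUCTION=1.
    jax.config.update("jax_pmap_no_rank_reduction", True)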
@@ -72,7 +72,8 @@ from jax._src.sharding_impls import (
    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
)
from jax._src.util import (safe_map, safe_zip, partition_list,
                           wrap_name, tuple_delete, distributed_debug_log,
                           wrap_name, tuple_update, tuple_delete,
                           distributed_debug_log,
                           unzip2, HashableFunction, weakref_lru_cache)

@@ -171,12 +172,13 @@ def batched_device_put(aval: core.ShapedArray,
        aval, sharding, bufs, committed=committed, _skip_checks=True)
  return xc.batched_device_put(aval, sharding, xs, list(devices), committed) # type: ignore

def shard_aval(size, axis: int, aval):
def _shard_aval(size, axis: int, aval):
  try:
    return shard_aval_handlers[type(aval)](size, axis, aval)
    return _shard_aval_handlers[type(aval)](size, axis, aval)
  except KeyError as err:
    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
    raise TypeError(f"No _shard_aval handler for type: {type(aval)}") from err
_shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}

def _shard_abstract_array(size, axis: int, x):
  try:
    if x.shape[axis] != size:
@@ -184,8 +186,11 @@ def _shard_abstract_array(size, axis: int, x):
                       f"shape {x.shape}")
  except IndexError:
    raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
  return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array
  if config.pmap_no_rank_reduction.value:
    return x.update(shape=tuple_update(x.shape, axis, 1))
  else:
    return x.update(shape=tuple_delete(x.shape, axis))
_shard_aval_handlers[ShapedArray] = _shard_abstract_array


def local_aval_to_result_handler(
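In shape terms, the branch above does the following for an (8, 100) aval mapped along axis 0; a pure-Python sketch mirroring tuple_update/tuple_delete:

    def shard_shape(shape, axis, size, no_rank_reduction):
      assert shape[axis] == size  # mirrors the size check earlier in the function
      if no_rank_reduction:
        return shape[:axis] + (1,) + shape[axis + 1:]   # tuple_update(shape, axis, 1)
      return shape[:axis] + shape[axis + 1:]            # tuple_delete(shape, axis)

    assert shard_shape((8, 100), 0, 8, no_rank_reduction=True) == (1, 100)
    assert shard_shape((8, 100), 0, 8, no_rank_reduction=False) == (100,)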
@@ -620,21 +625,39 @@ def find_replicas(
  num_global_replicas = global_axis_size * jaxpr_replicas
  return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)

@lu.transformation
def _change_argument_ranks(in_axes, out_axes_thunk, *args):
  args = tuple(
      arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,))
      for in_axis, arg in zip(in_axes, args)
  )
  results = yield (args, {})
  out_axes = out_axes_thunk()
  yield tuple(
      x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,))
      for x, axis in zip(results, out_axes)
  )


def stage_parallel_callable(
    pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
  sharded_avals = tuple(
      shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      _shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      for axis, aval in safe_zip(pci.in_axes, pci.avals))

  orig_fun = fun
  if config.pmap_no_rank_reduction.value:
    fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk)
  else:
    fun = orig_fun
  with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
        fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
  jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  assert len(out_sharded_avals) == len(pci.out_axes), (
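The transformation above only changes ranks at the boundary: per argument it squeezes the mapped axis on the way in and expands it again on each output. A small runnable illustration of that round trip with jax.lax (shapes only, not the transformation machinery itself):

    import jax.numpy as jnp
    from jax import lax

    shard = jnp.zeros((1, 100))                       # a rank-preserving pmap shard
    inner = lax.squeeze(shard, dimensions=(0,))       # what the traced function sees
    outer = lax.expand_dims(inner, dimensions=(0,))   # restored on the way out
    assert inner.shape == (100,) and outer.shape == (1, 100)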
@@ -666,7 +689,7 @@ def lower_parallel_callable(
    is_explicit_global_axis_size: bool,
    avals: Sequence[core.AbstractValue],
    *,
    lowering_parameters: mlir.LoweringParameters):
    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -782,6 +805,35 @@ def lower_parallel_callable(
      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)


def _pmap_unmap_shaped_array(
    size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
  named_shape = dict(aval.named_shape)
  named_shape.pop(axis_name, None)  # TODO: make this mandatory
  if axis is None: return aval.update(named_shape=named_shape)
  elif type(axis) is int:
    return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
                       named_shape=named_shape, weak_type=aval.weak_type)
  else: raise TypeError(axis)


AvalMapHandlerPair = tuple[Any, Callable]
_pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
    ShapedArray: (Any, _pmap_unmap_shaped_array),
}

def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
                        aval: core.AbstractValue) -> core.AbstractValue:
  if not config.pmap_no_rank_reduction.value:
    return core.unmapped_aval(size, axis_name, axis, aval)

  _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
  if handler is not None:
    return handler(size, axis_name, axis, aval)
  else:
    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")


class PmapComputation(stages.XlaLowering):
  _hlo: ir.Module
  _executable: PmapExecutable | None
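Shape-wise, _pmap_unmap_shaped_array widens the existing size-1 mapped axis back to the full axis size, whereas the classic core.unmapped_aval path re-inserts an axis that was removed; a pure-Python sketch of just that arithmetic:

    def unmap_shape(shard_shape, axis, size, no_rank_reduction):
      if no_rank_reduction:
        # Mapped axis is still present with extent 1; widen it (tuple_update).
        return shard_shape[:axis] + (size,) + shard_shape[axis + 1:]
      # Classic path: the mapped axis was dropped, so re-insert it (tuple_insert).
      return shard_shape[:axis] + (size,) + shard_shape[axis:]

    assert unmap_shape((1, 100), 0, 8, no_rank_reduction=True) == (8, 100)
    assert unmap_shape((100,), 0, 8, no_rank_reduction=False) == (8, 100)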
@@ -938,7 +990,7 @@ class UnloadedPmapExecutable:

    local_unmapped_avals = [
        _cast_to_shaped_array(
            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
            _pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
        if out_axis is not None else aval
        for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
    out_specs = [
@@ -4547,7 +4547,7 @@ def _array_copy(arr: ArrayLike) -> Array:
def _which_dim_sharded(s: PmapSharding) -> int | None:
  sharded_dim = None
  for i, s in enumerate(s.sharding_spec.sharding):
    if isinstance(s, pxla.Unstacked):
    if isinstance(s, (pxla.Unstacked, pxla.Chunked)):
      sharded_dim = i
      break
  return sharded_dim
@@ -541,7 +541,16 @@ class PmapSharding(XLACompatibleSharding):
    num_ways_sharded = None
    for s in sharding_spec.sharding:
      if isinstance(s, sharding_specs.Unstacked):
        assert num_ways_sharded is None
        num_ways_sharded = s.size
      elif isinstance(s, sharding_specs.Chunked):
        assert num_ways_sharded is None
        if len(s.chunks) == 1:
          num_ways_sharded = s.chunks[0]
        else:
          raise NotImplementedError(
              'Multiple chunks in Chunked dimension not supported.')

    if num_ways_sharded is None:
      raise NotImplementedError(
          '`None` to sharded_dim is not supported. Please file a jax '
@@ -581,7 +590,7 @@ class PmapSharding(XLACompatibleSharding):
  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    for s in self.sharding_spec.sharding:
      if isinstance(s, sharding_specs.Unstacked):
      if isinstance(s, (sharding_specs.Unstacked, sharding_specs.Chunked)):
        return False
    return True
@@ -596,6 +605,13 @@ class PmapSharding(XLACompatibleSharding):
      if isinstance(s, sharding_specs.Unstacked):
        sharded_dim = i
        sharded_dim_size = s.size
        sharded_shape = util.tuple_delete(global_shape, sharded_dim)
        break
      elif isinstance(s, sharding_specs.Chunked):
        sharded_dim = i
        assert len(s.chunks) == 1, s.chunks
        sharded_dim_size = s.chunks[0]
        sharded_shape = util.tuple_update(global_shape, sharded_dim, 1)
        break
    if sharded_dim is None:
      return global_shape
@@ -605,7 +621,7 @@ class PmapSharding(XLACompatibleSharding):
          f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
          f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
          f'the number of devices={len(self._device_assignment)}')
    return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
    return sharded_shape


def _op_sharding_to_pos_sharding(
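Taken together with the constructor change above, the public sharding API reflects the new shard shapes too; a hedged sketch using PmapSharding.default (assumes its current signature in jax.sharding, and a default pmap-style placement over the local devices):

    import jax
    from jax.sharding import PmapSharding

    n = jax.local_device_count()
    s = PmapSharding.default(shape=(n, 100), sharded_dim=0)
    # Rank-reducing spec (default): (100,); rank-preserving spec when
    # jax_pmap_no_rank_reduction is enabled: (1, 100).
    print(s.shard_shape((n, 100)))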
@@ -36,6 +36,7 @@ from typing import Union

import numpy as np

from jax._src import config
from jax._src import util
from jax._src.lib import pmap_lib

@@ -188,9 +189,14 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
    return a
  # replication_factor represents the product of inner pmaps, so it goes
  # after the outer pmapped axis at index 0
  if config.pmap_no_rank_reduction.value:
    sharding = util.tuple_update(
        pspec.sharding, map_axis, Chunked([axis_size]))
  else:
    sharding = util.tuple_insert(
        pspec.sharding, map_axis, Unstacked(axis_size))
  return ShardingSpec(
      sharding=util.tuple_insert(
          pspec.sharding, map_axis, Unstacked(axis_size)),
      sharding=sharding,
      mesh_mapping=itertools.chain(
          [ShardedAxis(sharded_in_axis)], maybe_replicate,
          map(shift_sharded_axis, pspec.mesh_mapping)))
@@ -203,7 +209,10 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0,
                              sharded_dim_size: int | None = None):
  if sharded_dim is not None:
    sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
    if config.pmap_no_rank_reduction.value:
      sharded_shape = util.tuple_update(shape, sharded_dim, 1)
    else:
      sharded_shape = util.tuple_delete(shape, sharded_dim)
    if sharded_dim_size is None:
      sharded_dim_size = shape[sharded_dim]
  else:
@@ -414,6 +414,10 @@ def tuple_delete(t, idx):
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + t[idx + 1:]

def tuple_update(t, idx, val):
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + (val,) + t[idx+1:]

class HashableFunction:
  """Decouples function equality and hash from its identity.

@@ -3073,7 +3073,10 @@ class ArrayPmapTest(jtu.JaxTestCase):
    self.assertEqual(out2.shape, (dc, dc, 2))
    for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):
      self.assertArraysEqual(s1.data, input_data[i])
      self.assertArraysEqual(s2.data, input_data)
      if config.pmap_no_rank_reduction.value:
        self.assertArraysEqual(s2.data, input_data[None])
      else:
        self.assertArraysEqual(s2.data, input_data)

  def test_pmap_array_sharding_mismatch(self):
    input_shape = (jax.device_count(), 2)
@@ -3105,7 +3108,7 @@ class ArrayPmapTest(jtu.JaxTestCase):

    def amap(f, xs):
      ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
      return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
      return jax.device_put_sharded(ys, jax.local_devices()[:2])

    # leading axis is batch dim (i.e. mapped/parallel dim), of size 2
    x = jnp.array([[1., 0., 0.],