From 33c4fc4fe2fd5160d65ed3a5b0937dee368eff74 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 10 Aug 2022 20:11:06 -0700 Subject: [PATCH] Pmap should output SDA like `Array`s to maintain the current behavior exactly. Split the shard_arg_handler for `Array` based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context. PiperOrigin-RevId: 466849614 --- jax/_src/api.py | 49 +++++----- jax/experimental/array.py | 26 +++-- jax/experimental/global_device_array.py | 4 +- jax/interpreters/pxla.py | 121 ++++++++++++------------ tests/BUILD | 5 + tests/pmap_test.py | 2 +- 6 files changed, 112 insertions(+), 95 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 57ca0e7a5..687f0d300 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1897,39 +1897,41 @@ class PmapCallInfo(NamedTuple): def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices): - from jax.experimental import sharding + from jax.experimental.sharding import PmapSharding + from jax.experimental.array import Array if not args: return - if in_devices is not None: - in_devices = np.array(in_devices) - - first_arr_devices = args[0].sharding.devices + first_device_assignment = None for a, i in safe_zip(args, in_axes_flat): - assert isinstance(a.sharding, sharding.PmapSharding) + if not isinstance(a, Array): + continue + if not isinstance(a.sharding, PmapSharding): + raise NotImplementedError('pmap only works with PmapSharding.') + if first_device_assignment is None: + first_device_assignment = a.sharding._device_assignment arr_sharding = a.sharding.sharded_dim - arr_devices = a.sharding.devices + arr_device_assignment = a.sharding._device_assignment if arr_sharding != i: raise ValueError('Array and pmap sharding does not match. Got pmap ' f'sharding: {i}, Array sharding: {arr_sharding} for ' f'arg: {a}') if (in_devices is not None and - arr_devices is not None and - not np.array_equal(arr_devices, in_devices)): + arr_device_assignment is not None and + arr_device_assignment != in_devices): raise ValueError('Devices passed to pmap and Array should be equal. ' - f'Got pmap devices: {devices}, Array devices: ' - f'{arr_devices} for arg: {a}') + f'Got pmap devices: {in_devices}, Array devices: ' + f'{arr_device_assignment} for arg: {a}') if (in_devices is None and - not np.array_equal(arr_devices, first_arr_devices)): + arr_device_assignment != first_device_assignment): raise ValueError('Devices of all `Array` inputs should be the same. 
' - f'Got array device: {arr_devices}, ' - f'another array device: {first_arr_devices}') - return first_arr_devices + f'Got array device: {arr_device_assignment}, ' + f'another array device: {first_device_assignment}') def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, global_arg_shapes, devices, args, kwargs): + donate_tuple, global_arg_shapes, in_devices, args, kwargs): f = lu.wrap_init(fun) if static_broadcasted_tuple: if max(static_broadcasted_tuple) >= len(args): @@ -1971,13 +1973,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, flat_fun, out_tree = flatten_fun(f, in_tree) if config.jax_array: - from jax.experimental.array import Array - if any(not isinstance(a, Array) for a in args): - raise ValueError('All arguments to pmap when `config.jax_array` is ' - 'enabled should be `Array`s.') - arr_devices = _check_in_pmap_sharding_with_arrays(args, in_axes_flat, devices) - if devices is None and arr_devices is not None: - devices = arr_devices + _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices) if any(out_axis is None for out_axis in tree_flatten(out_axes)): raise NotImplementedError("None out_axes in pmap are not supported yet") @@ -2011,7 +2007,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, local_axis_size=local_axis_size, global_arg_shapes_flat=global_arg_shapes_flat, out_axes_thunk=out_axes_thunk, - devices=None if devices is None else tuple(devices)) + devices=None if in_devices is None else tuple(in_devices)) def _get_f_mapped( @@ -2199,8 +2195,9 @@ def _cpp_pmap( return out, fastpath_data - cpp_mapped_f = pmap_lib.pmap(fun, cache_miss, - static_broadcasted_tuple, pxla._shard_arg) + cpp_mapped_f = pmap_lib.pmap( + fun, cache_miss, static_broadcasted_tuple, + partial(pxla._shard_arg, mode=pxla.InputsHandlerMode.pmap)) pmap_f = wraps(fun)(cpp_mapped_f) diff --git a/jax/experimental/array.py b/jax/experimental/array.py index 1b035ec55..07bac9f8e 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -21,7 +21,7 @@ from jax import core from jax._src import api_util from jax._src import dispatch from jax._src.config import config -from jax._src.util import prod +from jax._src.util import prod, safe_zip from jax._src.lib import xla_client as xc from jax._src.api import device_put from jax.interpreters import pxla, xla @@ -261,12 +261,26 @@ def _device_put_array(x, device: Optional[Device]): dispatch.device_put_handlers[Array] = _device_put_array -def _array_shard_arg(x, devices, indices): - return x._arrays +def _array_shard_arg(x, devices, indices, mode): + # TODO(yashkatariya): Remove the `mode` handling and try to consolidate the + # code paths. + if mode == pxla.InputsHandlerMode.pmap: + # sharding mismatch between `Array` and pmap sharding is checked in api.py's + # `_check_in_pmap_sharding_with_arrays` function. 
+ return [buf if buf.device() == d else buf.copy_to_device(d) + for buf, d in safe_zip(x._arrays, devices)] + else: + return x._arrays pxla.shard_arg_handlers[Array] = _array_shard_arg -def _array_result_handler(global_aval, out_sharding): +def _array_global_result_handler(global_aval, out_sharding): return lambda bufs: Array(global_aval.shape, out_sharding, bufs, committed=True) -pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_result_handler -pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_result_handler +pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler +pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler + + +def _array_local_result_handler(aval, sharding, indices): + return lambda bufs: Array(aval.shape, sharding, bufs, committed=True) +pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler +pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index ed5898d25..d39ef3359 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -561,7 +561,9 @@ xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \ lambda x: core.ShapedArray(x.shape, x.dtype) -def _gda_shard_arg(x, devices, indices): +def _gda_shard_arg(x, devices, indices, mode): + if mode == pxla.InputsHandlerMode.pmap: + raise RuntimeError('GDA is not supported with pmap.') return x._device_buffers pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 47b75b08d..536079149 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -368,7 +368,7 @@ def spec_to_indices(shape: Tuple[int, ...], def identity(x): return x -def _shard_arg(arg, devices, arg_indices): +def _shard_arg(arg, devices, arg_indices, mode): """Returns a list of size len(devices) containing per-device buffers. For the C++ pmap path, we fallback to Python (this function) to shard @@ -378,6 +378,7 @@ def _shard_arg(arg, devices, arg_indices): arg: The Python argument. devices: The list of devices to shard over. arg_indices: A list of `len(devices)` indices to use to shard the argument. + mode: An enum telling whether shard_arg is executed via pmap or pjit/xmap. """ if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices: # The shard_arg_handlers allow an extensible set of types to be sharded, but @@ -390,12 +391,13 @@ def _shard_arg(arg, devices, arg_indices): ] else: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, devices, arg_indices) + return shard_arg_handlers[type(arg)](arg, devices, arg_indices, mode) @profiler.annotate_function def shard_args(devices: Sequence[xb.xla_client.Device], indices: Sequence[Sequence[Index]], + mode: InputsHandlerMode, args) -> Sequence[Sequence[xb.xla_client.Buffer]]: """Shard each argument data array along its leading axis. @@ -411,16 +413,16 @@ def shard_args(devices: Sequence[xb.xla_client.Device], A list of length matching args, containing lists of per-device buffers for each argument. 
""" - return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)] + return [_shard_arg(arg, devices, indices[i], mode) for i, arg in enumerate(args)] -shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {} -def _shard_array(x, devices, indices): +shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {} +def _shard_array(x, devices, indices, mode): return device_put([x[i] for i in indices], devices) for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_device_array(x, devices, indices): +def _shard_device_array(x, devices, indices, mode): start_indices, limit_indices, removed_dims = unzip3( _as_slice_indices(x, idx) for idx in indices) shards = x._multi_slice(start_indices, limit_indices, removed_dims) @@ -524,9 +526,15 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): return PartitionSpec(*partitions) +class OutputType(enum.Enum): + Array = 0 + GlobalDeviceArray = 1 + ShardedDeviceArray = 2 + + def local_aval_to_result_handler( aval: core.AbstractValue, - sharding_spec: Optional[ShardingSpec], + sharding: XLACompatibleSharding, indices: Optional[Tuple[Index]], ) -> Callable[[List[xb.xla_client.Buffer]], Any]: """Returns a function for handling the raw buffers of a single output aval. @@ -543,24 +551,25 @@ def local_aval_to_result_handler( for this output. The function will return an object suitable for returning to the user, e.g. a ShardedDeviceArray. """ + if config.jax_array: + output_type = OutputType.Array + else: + output_type = OutputType.ShardedDeviceArray try: - return local_result_handlers[type(aval)](aval, sharding_spec, indices) + return local_result_handlers[(type(aval), output_type)](aval, sharding, indices) except KeyError as err: raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]] -local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} -def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices): +local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {} + +def sda_array_result_handler(aval: ShapedArray, sharding, indices): + sharding_spec = _get_sharding_specs([sharding], [aval])[0] return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs, indices) -local_result_handlers[ShapedArray] = sda_array_result_handler -local_result_handlers[ConcreteArray] = sda_array_result_handler - - -class OutputType(enum.Enum): - Array = 0 - GlobalDeviceArray = 1 +local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler +local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler def global_aval_to_result_handler( @@ -839,7 +848,7 @@ def _hashable_index(idx): # The fast path is handled directly in shard_args(). # TODO(skye): is there a simpler way to rewrite this using sharding_spec? -def _shard_sharded_device_array_slow_path(x, devices, indices): +def _shard_sharded_device_array_slow_path(x, devices, indices, mode): candidates = defaultdict(list) for buf, idx in safe_zip(x.device_buffers, x.indices): candidates[_hashable_index(idx)].append(buf) @@ -851,7 +860,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices): if not candidates_list: # This array isn't sharded correctly. Reshard it via host roundtrip. # TODO(skye): more efficient reshard? 
- return shard_arg_handlers[type(x._value)](x._value, devices, indices) + return shard_arg_handlers[type(x._value)](x._value, devices, indices, mode) # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: @@ -1293,8 +1302,6 @@ class PmapExecutable(stages.XlaExecutable): ]) local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals) - # TODO(yashkatariya): Fix the input handling of `Array`s that span over - # multiple processes. Add multi-process tests for pmap. input_sharding_specs = [ _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size, parts.local_num_partitions, arg_parts, aval, in_axis) @@ -1312,40 +1319,26 @@ class PmapExecutable(stages.XlaExecutable): if parts.local_out_parts is None: local_out_parts = (None,) * nouts - if config.jax_array: - global_unmapped_avals = [ + local_out_avals = [ + get_local_aval(aval, parts, lparts) + for aval, parts, lparts + in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)] + local_unmapped_avals = [ core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval) if out_axis is not None else aval - for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)] - global_out_specs = [ - _pmap_sharding_spec(replicas.num_global_replicas, pci.axis_size, - parts.num_partitions, op, aval, out_axis) - for op, aval, out_axis in safe_zip( - out_parts, shards.out_sharded_avals, pci.out_axes)] - pmap_shardings = _get_pmap_sharding(device_assignment, global_out_specs) - handle_outs = global_avals_to_results_handler( - global_unmapped_avals, pmap_shardings) - else: - local_out_avals = [ - get_local_aval(aval, parts, lparts) - for aval, parts, lparts - in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)] - local_unmapped_avals = [ - core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval) - if out_axis is not None else aval - for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)] - out_specs = [ - _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size, - parts.local_num_partitions, out_parts, aval, out_axis) - for out_parts, aval, out_axis in safe_zip( - local_out_parts, local_out_avals, pci.out_axes)] - pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs) - handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings) + for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)] + out_specs = [ + _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size, + parts.local_num_partitions, out_parts, aval, out_axis) + for out_parts, aval, out_axis in safe_zip( + local_out_parts, local_out_avals, pci.out_axes)] + pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs) + handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings) if hasattr(pci.backend, "compile_replicated"): execute_fun = pci.backend.compile_replicated( xla_computation, compile_options, pci.avals, input_indices, - in_shardings, handle_outs) + in_shardings, InputsHandlerMode.pmap, handle_outs) # TODO(frostig): need `compile_replicated` to give us the XLA executable return PmapExecutable(None, execute_fun, None, pci.avals) @@ -1354,7 +1347,7 @@ class PmapExecutable(stages.XlaExecutable): compiled = dispatch.compile_or_get_cached( pci.backend, xla_computation, compile_options, host_callbacks) handle_args = InputsHandler( - compiled.local_devices(), in_shardings, input_indices) + compiled.local_devices(), in_shardings, input_indices, InputsHandlerMode.pmap) execute_fun = 
ExecuteReplicated(compiled, pci.backend, handle_args, handle_outs, unordered_effects, keepalive) fingerprint = getattr(compiled, "fingerprint", None) @@ -1520,14 +1513,21 @@ def _safe_div(x, y): return result -class InputsHandler: - __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") +class InputsHandlerMode(enum.Enum): + pmap = 0 + pjit_or_xmap = 1 - def __init__(self, local_devices, in_shardings, input_indices): - self.handler = partial(shard_args, local_devices, input_indices) + +class InputsHandler: + __slots__ = ("handler", "local_devices", "in_shardings", "input_indices", + "mode") + + def __init__(self, local_devices, in_shardings, input_indices, mode): + self.handler = partial(shard_args, local_devices, input_indices, mode) self.local_devices = local_devices self.in_shardings = in_shardings self.input_indices = input_indices + self.mode = mode def __call__(self, input_buffers): return self.handler(input_buffers) @@ -1536,7 +1536,8 @@ class InputsHandler: return ("InputsHandler(\n" f"local_devices={self.local_devices},\n" f"in_shardings={self.in_shardings},\n" - f"input_indices={self.input_indices})") + f"input_indices={self.input_indices})\n" + f"mode={self.mode}") class ResultsHandler: @@ -1572,13 +1573,11 @@ def _get_sharding_specs( def local_avals_to_results_handler( unmapped_local_out_avals: Sequence[Optional[ShapedArray]], local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler: - local_out_specs = _get_sharding_specs( - local_shardings, cast(Sequence[ShapedArray], unmapped_local_out_avals)) out_indices = [tuple(s.devices_indices_map(aval.shape).values()) for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)] handlers = [ - local_aval_to_result_handler(aval, spec, idcs) - for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices) + local_aval_to_result_handler(aval, s, idcs) + for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices) ] return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals) @@ -2761,7 +2760,7 @@ class MeshExecutable(stages.XlaExecutable): global_out_avals, out_shardings) # type: ignore # arg-type unsafe_call = backend.compile_replicated( computation, compile_options, input_avals, input_indices, - in_shardings, handle_outs) + in_shardings, InputsHandlerMode.pjit_or_xmap, handle_outs) xla_executable = None else: with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " @@ -2784,7 +2783,7 @@ class MeshExecutable(stages.XlaExecutable): handle_outs = global_avals_to_results_handler( global_out_avals, out_shardings) # type: ignore # arg-type handle_args = InputsHandler(xla_executable.local_devices(), in_shardings, - input_indices) + input_indices, InputsHandlerMode.pjit_or_xmap) unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs, unordered_effects, keepalive) diff --git a/tests/BUILD b/tests/BUILD index 79e8edaf1..3515c6b1e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -189,6 +189,11 @@ jax_test( jax_test( name = "pjit_test", srcs = ["pjit_test.py"], + shard_count = { + "cpu": 5, + "gpu": 5, + "tpu": 5, + }, tags = ["multiaccelerator"], deps = [ "//jax:experimental", diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 1358d4b1b..63c779aca 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2901,7 +2901,7 @@ class ShardArgsTest(jtu.JaxTestCase): x = np.arange(prod(shape)).reshape(shape) arg = make_arg(x) bufs = pxla.shard_args(jax.devices()[:nshards], - [indices], [arg]) + [indices], 
pxla.InputsHandlerMode.pmap, [arg]) self.assertEqual(len(bufs), 1) self.assertEqual(len(bufs[0]), nshards) for buf, idx in zip(bufs[0], indices):
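
A note on the mechanism, for readers skimming the diff: the patch threads a new
`mode` argument (`InputsHandlerMode.pmap` vs. `InputsHandlerMode.pjit_or_xmap`)
through `shard_arg_handlers` and `InputsHandler`. Under pmap,
`_array_shard_arg` copies an `Array`'s per-shard buffers onto the requested
devices when they are not already resident there; under pjit/xmap the buffers
are passed through untouched. The sketch below is a minimal, self-contained
model of that dispatch in plain Python, not JAX's actual internals:
`ToyBuffer`, `ToyArray`, `toy_shard_arg`, and the device strings are invented
stand-ins for `xla_client.Buffer`, `Array._arrays`, and `_array_shard_arg`.

    # Minimal sketch of the mode-based shard_arg dispatch added in this patch.
    # All names here are hypothetical stand-ins; no jax internals are used.
    import enum
    from typing import List, NamedTuple


    class InputsHandlerMode(enum.Enum):
        pmap = 0
        pjit_or_xmap = 1


    class ToyBuffer(NamedTuple):
        device: str
        data: tuple

        def copy_to_device(self, device: str) -> "ToyBuffer":
            # Stand-in for Buffer.copy_to_device(): return a shard resident
            # on the requested device.
            return ToyBuffer(device, self.data)


    class ToyArray(NamedTuple):
        # One buffer per shard, mirroring Array._arrays.
        arrays: List[ToyBuffer]


    def toy_shard_arg(x: ToyArray, devices: List[str],
                      mode: InputsHandlerMode) -> List[ToyBuffer]:
        # Mirrors the pmap/pjit split in _array_shard_arg: pmap may have to
        # move shards onto the devices chosen for the computation, while
        # pjit/xmap trusts the Array's existing placement.
        if mode is InputsHandlerMode.pmap:
            return [buf if buf.device == d else buf.copy_to_device(d)
                    for buf, d in zip(x.arrays, devices)]
        return list(x.arrays)


    if __name__ == "__main__":
        arr = ToyArray([ToyBuffer("dev:0", (0, 1)), ToyBuffer("dev:1", (2, 3))])
        # pmap wants the shards on ("dev:1", "dev:0"), so both get copied.
        moved = toy_shard_arg(arr, ["dev:1", "dev:0"], InputsHandlerMode.pmap)
        assert [b.device for b in moved] == ["dev:1", "dev:0"]
        # pjit/xmap leaves the buffers where the Array already placed them.
        kept = toy_shard_arg(arr, ["dev:1", "dev:0"],
                             InputsHandlerMode.pjit_or_xmap)
        assert [b.device for b in kept] == ["dev:0", "dev:1"]

In the patch itself the same split also applies on the output side:
`local_result_handlers` is now keyed by `(aval type, OutputType)`, so with
`config.jax_array` enabled pmap builds its outputs via
`_array_local_result_handler` and returns `Array`s, while the legacy path
keeps returning ShardedDeviceArrays through `sda_array_result_handler`.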