Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.

PiperOrigin-RevId: 466849614
This commit is contained in:
Yash Katariya 2022-08-10 20:11:06 -07:00 committed by jax authors
parent 0a783ca156
commit 33c4fc4fe2
6 changed files with 112 additions and 95 deletions

View File

@ -1897,39 +1897,41 @@ class PmapCallInfo(NamedTuple):
def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
from jax.experimental import sharding
from jax.experimental.sharding import PmapSharding
from jax.experimental.array import Array
if not args:
return
if in_devices is not None:
in_devices = np.array(in_devices)
first_arr_devices = args[0].sharding.devices
first_device_assignment = None
for a, i in safe_zip(args, in_axes_flat):
assert isinstance(a.sharding, sharding.PmapSharding)
if not isinstance(a, Array):
continue
if not isinstance(a.sharding, PmapSharding):
raise NotImplementedError('pmap only works with PmapSharding.')
if first_device_assignment is None:
first_device_assignment = a.sharding._device_assignment
arr_sharding = a.sharding.sharded_dim
arr_devices = a.sharding.devices
arr_device_assignment = a.sharding._device_assignment
if arr_sharding != i:
raise ValueError('Array and pmap sharding does not match. Got pmap '
f'sharding: {i}, Array sharding: {arr_sharding} for '
f'arg: {a}')
if (in_devices is not None and
arr_devices is not None and
not np.array_equal(arr_devices, in_devices)):
arr_device_assignment is not None and
arr_device_assignment != in_devices):
raise ValueError('Devices passed to pmap and Array should be equal. '
f'Got pmap devices: {devices}, Array devices: '
f'{arr_devices} for arg: {a}')
f'Got pmap devices: {in_devices}, Array devices: '
f'{arr_device_assignment} for arg: {a}')
if (in_devices is None and
not np.array_equal(arr_devices, first_arr_devices)):
arr_device_assignment != first_device_assignment):
raise ValueError('Devices of all `Array` inputs should be the same. '
f'Got array device: {arr_devices}, '
f'another array device: {first_arr_devices}')
return first_arr_devices
f'Got array device: {arr_device_assignment}, '
f'another array device: {first_device_assignment}')
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, devices, args, kwargs):
donate_tuple, global_arg_shapes, in_devices, args, kwargs):
f = lu.wrap_init(fun)
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
@ -1971,13 +1973,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
flat_fun, out_tree = flatten_fun(f, in_tree)
if config.jax_array:
from jax.experimental.array import Array
if any(not isinstance(a, Array) for a in args):
raise ValueError('All arguments to pmap when `config.jax_array` is '
'enabled should be `Array`s.')
arr_devices = _check_in_pmap_sharding_with_arrays(args, in_axes_flat, devices)
if devices is None and arr_devices is not None:
devices = arr_devices
_check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
@ -2011,7 +2007,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
local_axis_size=local_axis_size,
global_arg_shapes_flat=global_arg_shapes_flat,
out_axes_thunk=out_axes_thunk,
devices=None if devices is None else tuple(devices))
devices=None if in_devices is None else tuple(in_devices))
def _get_f_mapped(
@ -2199,8 +2195,9 @@ def _cpp_pmap(
return out, fastpath_data
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
static_broadcasted_tuple, pxla._shard_arg)
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple,
partial(pxla._shard_arg, mode=pxla.InputsHandlerMode.pmap))
pmap_f = wraps(fun)(cpp_mapped_f)

View File

@ -21,7 +21,7 @@ from jax import core
from jax._src import api_util
from jax._src import dispatch
from jax._src.config import config
from jax._src.util import prod
from jax._src.util import prod, safe_zip
from jax._src.lib import xla_client as xc
from jax._src.api import device_put
from jax.interpreters import pxla, xla
@ -261,12 +261,26 @@ def _device_put_array(x, device: Optional[Device]):
dispatch.device_put_handlers[Array] = _device_put_array
def _array_shard_arg(x, devices, indices):
return x._arrays
def _array_shard_arg(x, devices, indices, mode):
# TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
# code paths.
if mode == pxla.InputsHandlerMode.pmap:
# sharding mismatch between `Array` and pmap sharding is checked in api.py's
# `_check_in_pmap_sharding_with_arrays` function.
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg
def _array_result_handler(global_aval, out_sharding):
def _array_global_result_handler(global_aval, out_sharding):
return lambda bufs: Array(global_aval.shape, out_sharding, bufs, committed=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_result_handler
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
def _array_local_result_handler(aval, sharding, indices):
return lambda bufs: Array(aval.shape, sharding, bufs, committed=True)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler

View File

@ -561,7 +561,9 @@ xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
lambda x: core.ShapedArray(x.shape, x.dtype)
def _gda_shard_arg(x, devices, indices):
def _gda_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
raise RuntimeError('GDA is not supported with pmap.')
return x._device_buffers
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg

View File

@ -368,7 +368,7 @@ def spec_to_indices(shape: Tuple[int, ...],
def identity(x): return x
def _shard_arg(arg, devices, arg_indices):
def _shard_arg(arg, devices, arg_indices, mode):
"""Returns a list of size len(devices) containing per-device buffers.
For the C++ pmap path, we fallback to Python (this function) to shard
@ -378,6 +378,7 @@ def _shard_arg(arg, devices, arg_indices):
arg: The Python argument.
devices: The list of devices to shard over.
arg_indices: A list of `len(devices)` indices to use to shard the argument.
mode: An enum telling whether shard_arg is executed via pmap or pjit/xmap.
"""
if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
# The shard_arg_handlers allow an extensible set of types to be sharded, but
@ -390,12 +391,13 @@ def _shard_arg(arg, devices, arg_indices):
]
else:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, mode)
@profiler.annotate_function
def shard_args(devices: Sequence[xb.xla_client.Device],
indices: Sequence[Sequence[Index]],
mode: InputsHandlerMode,
args) -> Sequence[Sequence[xb.xla_client.Buffer]]:
"""Shard each argument data array along its leading axis.
@ -411,16 +413,16 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
A list of length matching args, containing lists of per-device buffers
for each argument.
"""
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
return [_shard_arg(arg, devices, indices[i], mode) for i, arg in enumerate(args)]
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
def _shard_array(x, devices, indices):
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {}
def _shard_array(x, devices, indices, mode):
return device_put([x[i] for i in indices], devices)
for _t in array_types:
shard_arg_handlers[_t] = _shard_array
def _shard_device_array(x, devices, indices):
def _shard_device_array(x, devices, indices, mode):
start_indices, limit_indices, removed_dims = unzip3(
_as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
@ -524,9 +526,15 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
return PartitionSpec(*partitions)
class OutputType(enum.Enum):
Array = 0
GlobalDeviceArray = 1
ShardedDeviceArray = 2
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding_spec: Optional[ShardingSpec],
sharding: XLACompatibleSharding,
indices: Optional[Tuple[Index]],
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
@ -543,24 +551,25 @@ def local_aval_to_result_handler(
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
if config.jax_array:
output_type = OutputType.Array
else:
output_type = OutputType.ShardedDeviceArray
try:
return local_result_handlers[type(aval)](aval, sharding_spec, indices)
return local_result_handlers[(type(aval), output_type)](aval, sharding, indices)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
local_result_handlers[ShapedArray] = sda_array_result_handler
local_result_handlers[ConcreteArray] = sda_array_result_handler
class OutputType(enum.Enum):
Array = 0
GlobalDeviceArray = 1
local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
def global_aval_to_result_handler(
@ -839,7 +848,7 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices):
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
candidates = defaultdict(list)
for buf, idx in safe_zip(x.device_buffers, x.indices):
candidates[_hashable_index(idx)].append(buf)
@ -851,7 +860,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
return shard_arg_handlers[type(x._value)](x._value, devices, indices)
return shard_arg_handlers[type(x._value)](x._value, devices, indices, mode)
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
@ -1293,8 +1302,6 @@ class PmapExecutable(stages.XlaExecutable):
])
local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
# TODO(yashkatariya): Fix the input handling of `Array`s that span over
# multiple processes. Add multi-process tests for pmap.
input_sharding_specs = [
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, arg_parts, aval, in_axis)
@ -1312,40 +1319,26 @@ class PmapExecutable(stages.XlaExecutable):
if parts.local_out_parts is None:
local_out_parts = (None,) * nouts
if config.jax_array:
global_unmapped_avals = [
local_out_avals = [
get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
local_unmapped_avals = [
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
global_out_specs = [
_pmap_sharding_spec(replicas.num_global_replicas, pci.axis_size,
parts.num_partitions, op, aval, out_axis)
for op, aval, out_axis in safe_zip(
out_parts, shards.out_sharded_avals, pci.out_axes)]
pmap_shardings = _get_pmap_sharding(device_assignment, global_out_specs)
handle_outs = global_avals_to_results_handler(
global_unmapped_avals, pmap_shardings)
else:
local_out_avals = [
get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
local_unmapped_avals = [
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
out_specs = [
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, out_parts, aval, out_axis)
for out_parts, aval, out_axis in safe_zip(
local_out_parts, local_out_avals, pci.out_axes)]
pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
out_specs = [
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, out_parts, aval, out_axis)
for out_parts, aval, out_axis in safe_zip(
local_out_parts, local_out_avals, pci.out_axes)]
pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)
if hasattr(pci.backend, "compile_replicated"):
execute_fun = pci.backend.compile_replicated(
xla_computation, compile_options, pci.avals, input_indices,
in_shardings, handle_outs)
in_shardings, InputsHandlerMode.pmap, handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
@ -1354,7 +1347,7 @@ class PmapExecutable(stages.XlaExecutable):
compiled = dispatch.compile_or_get_cached(
pci.backend, xla_computation, compile_options, host_callbacks)
handle_args = InputsHandler(
compiled.local_devices(), in_shardings, input_indices)
compiled.local_devices(), in_shardings, input_indices, InputsHandlerMode.pmap)
execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
handle_outs, unordered_effects, keepalive)
fingerprint = getattr(compiled, "fingerprint", None)
@ -1520,14 +1513,21 @@ def _safe_div(x, y):
return result
class InputsHandler:
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
class InputsHandlerMode(enum.Enum):
pmap = 0
pjit_or_xmap = 1
def __init__(self, local_devices, in_shardings, input_indices):
self.handler = partial(shard_args, local_devices, input_indices)
class InputsHandler:
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices",
"mode")
def __init__(self, local_devices, in_shardings, input_indices, mode):
self.handler = partial(shard_args, local_devices, input_indices, mode)
self.local_devices = local_devices
self.in_shardings = in_shardings
self.input_indices = input_indices
self.mode = mode
def __call__(self, input_buffers):
return self.handler(input_buffers)
@ -1536,7 +1536,8 @@ class InputsHandler:
return ("InputsHandler(\n"
f"local_devices={self.local_devices},\n"
f"in_shardings={self.in_shardings},\n"
f"input_indices={self.input_indices})")
f"input_indices={self.input_indices})\n"
f"mode={self.mode}")
class ResultsHandler:
@ -1572,13 +1573,11 @@ def _get_sharding_specs(
def local_avals_to_results_handler(
unmapped_local_out_avals: Sequence[Optional[ShapedArray]],
local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
local_out_specs = _get_sharding_specs(
local_shardings, cast(Sequence[ShapedArray], unmapped_local_out_avals))
out_indices = [tuple(s.devices_indices_map(aval.shape).values())
for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
handlers = [
local_aval_to_result_handler(aval, spec, idcs)
for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
local_aval_to_result_handler(aval, s, idcs)
for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices)
]
return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)
@ -2761,7 +2760,7 @@ class MeshExecutable(stages.XlaExecutable):
global_out_avals, out_shardings) # type: ignore # arg-type
unsafe_call = backend.compile_replicated(
computation, compile_options, input_avals, input_indices,
in_shardings, handle_outs)
in_shardings, InputsHandlerMode.pjit_or_xmap, handle_outs)
xla_executable = None
else:
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
@ -2784,7 +2783,7 @@ class MeshExecutable(stages.XlaExecutable):
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings) # type: ignore # arg-type
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
input_indices)
input_indices, InputsHandlerMode.pjit_or_xmap)
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
handle_outs, unordered_effects, keepalive)

View File

@ -189,6 +189,11 @@ jax_test(
jax_test(
name = "pjit_test",
srcs = ["pjit_test.py"],
shard_count = {
"cpu": 5,
"gpu": 5,
"tpu": 5,
},
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",

View File

@ -2901,7 +2901,7 @@ class ShardArgsTest(jtu.JaxTestCase):
x = np.arange(prod(shape)).reshape(shape)
arg = make_arg(x)
bufs = pxla.shard_args(jax.devices()[:nshards],
[indices], [arg])
[indices], pxla.InputsHandlerMode.pmap, [arg])
self.assertEqual(len(bufs), 1)
self.assertEqual(len(bufs[0]), nshards)
for buf, idx in zip(bufs[0], indices):