Make eager pmap tests pass with Array. Also add a slow path for Array in pmap, similar to what SDA has; this is required for eager pmap. Adding the slow path removes the need for the sharding checks in api.py: SDA doesn't do those checks, and if an input's sharding does not match the pmap sharding it simply falls back to the slow path (exactly like SDA).

PiperOrigin-RevId: 468843310
Yash Katariya 2022-08-19 21:36:43 -07:00 committed by jax authors
parent 3a2f25ff31
commit f905d989c1
6 changed files with 70 additions and 115 deletions
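
Note (not part of the diff): a minimal usage sketch of the behavior this change enables, assuming a multi-device runtime with `jax_array` turned on. None of the code below appears in the commit; it mirrors the updated test_pmap_array_devices_mismatch further down, where a device-order mismatch previously raised and now just runs.

import jax
import numpy as np

jax.config.update('jax_array', True)

devs = jax.devices()
x = jax.device_put_replicated(np.arange(4), devs)   # sharded over devs, one shard per device
f = jax.pmap(lambda v: v + 1, devices=devs[::-1])   # pmap wants the reverse device order

# Before this change: ValueError("Devices passed to pmap and Array should be equal. ...")
# After: pmap accepts the Array and moves its shards onto its own devices, the
# same way SDA inputs are handled (copy / SDA-style slow path), then runs.
y = f(x)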

View File

@@ -1904,42 +1904,6 @@ class PmapCallInfo(NamedTuple):
devices: Optional[Sequence[xc.Device]]
def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
from jax.experimental.sharding import PmapSharding, SingleDeviceSharding
from jax.experimental.array import Array
if not args:
return
first_device_assignment = None
for a, i in safe_zip(args, in_axes_flat):
if not isinstance(a, Array):
continue
if isinstance(a.sharding, SingleDeviceSharding):
continue
if not isinstance(a.sharding, PmapSharding):
raise NotImplementedError('pmap only works with PmapSharding.')
if first_device_assignment is None:
first_device_assignment = a.sharding._device_assignment
arr_sharding = a.sharding.sharded_dim
arr_device_assignment = a.sharding._device_assignment
if arr_sharding != i:
raise ValueError('Array and pmap sharding does not match. Got pmap '
f'sharding: {i}, Array sharding: {arr_sharding} for '
f'arg: {a}')
if (in_devices is not None and
arr_device_assignment is not None and
arr_device_assignment != in_devices):
raise ValueError('Devices passed to pmap and Array should be equal. '
f'Got pmap devices: {in_devices}, Array devices: '
f'{arr_device_assignment} for arg: {a}')
if (in_devices is None and
arr_device_assignment != first_device_assignment):
raise ValueError('Devices of all `Array` inputs should be the same. '
f'Got array device: {arr_device_assignment}, '
f'another array device: {first_device_assignment}')
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, in_devices, args, kwargs):
f = lu.wrap_init(fun)
@@ -1982,9 +1946,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
flat_fun, out_tree = flatten_fun(f, in_tree)
if config.jax_array:
_check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
# NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,

View File

@@ -120,7 +120,7 @@ class Array:
if config.jax_enable_checks:
assert all(db.dtype == self.dtype for db in self._arrays), (
"Input arrays to `Array` must have matching dtypes, "
f"got: {[db.dtype for db in self._arrays]}")
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")
# Rearrange arrays based on the device assignment.
if isinstance(sharding, XLACompatibleSharding):
@@ -378,16 +378,23 @@ def _device_put_array(x, device: Optional[Device]):
dispatch.device_put_handlers[Array] = _device_put_array
def _array_shard_arg(x, devices, indices, mode):
# TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
# code paths.
if mode == pxla.InputsHandlerMode.pmap:
# sharding mismatch between `Array` and pmap sharding is checked in api.py's
# `_check_in_pmap_sharding_with_arrays` function.
if isinstance(x.sharding, SingleDeviceSharding):
return pxla._shard_device_array(x, devices, indices, mode)
def _array_pmap_shard_arg(x, devices, indices, mode):
if isinstance(x.sharding, SingleDeviceSharding):
return pxla._shard_device_array(x, devices, indices, mode)
# If the sharding of Array does not match pmap's sharding then take the slow
# path which is similar to what SDA does. This slow path reroute only happens
# for `pmap`.
if indices == tuple(x.sharding.devices_indices_map(x.shape).values()):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
else:
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
def _array_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg
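
Note (not part of the diff): a hedged, user-level illustration of the input routes `_array_pmap_shard_arg` above chooses between. The helper names in the comments refer to the internals shown in this hunk; the code itself only uses public API and is an assumption about typical usage, not test code from the commit.

import jax
import numpy as np

n = jax.device_count()
f = jax.pmap(lambda v: v * 2)

out = f(np.arange(n))        # output is already pmap-sharded across all devices
_ = f(out)                   # indices match pmap's -> per-device buffers are reused

single = jax.device_put(np.arange(n), jax.devices()[0])
_ = f(single)                # SingleDeviceSharding -> _shard_device_array route

# An Array whose shards don't line up with pmap's expected indices or devices is
# now resharded (copy_to_device or the SDA-style slow path) instead of raising
# in api.py, per the check removed in the first hunk.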

View File

@@ -258,13 +258,6 @@ class PmapSharding(XLACompatibleSharding):
def device_set(self) -> Set[Device]:
return set(self.devices.flat)
@pxla.maybe_cached_property
def sharded_dim(self):
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
return i
return None
@functools.lru_cache(maxsize=4096)
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:

View File

@@ -853,8 +853,16 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
from jax.experimental.array import Array
candidates = defaultdict(list)
for buf, idx in safe_zip(x.device_buffers, x.indices):
if isinstance(x, Array):
bufs = x._arrays
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
bufs = x.device_buffers
arr_indices = x.indices
for buf, idx in safe_zip(bufs, arr_indices):
candidates[_hashable_index(idx)].append(buf)
bufs = []
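
Note (not part of the diff): the rest of `_shard_sharded_device_array_slow_path` is not shown in this hunk; it consumes `candidates`. Below is a self-contained toy in plain Python (no JAX, every name invented here) of the regroup-by-index idea: bucket existing shards by the index they cover, then for each (index, device) pair pmap expects, reuse a shard already on that device if possible, otherwise copy one over.

from collections import defaultdict, namedtuple

Shard = namedtuple('Shard', ['index', 'device', 'data'])

def regroup(shards, wanted):
  # wanted: sequence of (index, device) pairs in pmap's expected order
  candidates = defaultdict(list)
  for s in shards:
    candidates[s.index].append(s)
  out = []
  for idx, dev in wanted:
    have = candidates[idx]
    on_dev = next((s for s in have if s.device == dev), None)
    # simulate buf.copy_to_device(dev) with a namedtuple _replace
    out.append(on_dev if on_dev is not None else have[0]._replace(device=dev))
  return out

shards = [Shard(0, 'dev1', 'a'), Shard(1, 'dev0', 'b')]
print(regroup(shards, [(0, 'dev0'), (1, 'dev1')]))  # both shards get "copied" across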
@@ -977,10 +985,13 @@ def _emap_impl(fun: lu.WrappedFun, *args,
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis, devices=devices, backend=backend,
donate_argnums=donate_argnums_)(
np.arange(axis_size), outval)
out = jax.pmap(
lambda _, x: x,
in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis,
devices=(None if devices is None else list(devices)),
backend=backend,
donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
new_outvals.append(out)
return new_outvals
@@ -1000,8 +1011,13 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
backend=info.backend, devices=info.devices)
f = jax.pmap(
f,
in_axes=in_axes,
axis_name=name,
out_axes=0,
backend=info.backend,
devices=(None if info.devices is None else list(info.devices)))
used_names.append(name)
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
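
Note (not part of the diff): a hedged sketch of the eager-pmap path the _emap_impl/_multi_pmap hunks above belong to. The assumption, based on the EagerPmapMixin tests, is that pmap takes this path when jit is disabled; the code only uses public API and runs either way.

import jax
import numpy as np

with jax.disable_jit():  # eager pmap: the _emap_impl path above
  out = jax.pmap(lambda v: v + 1)(np.arange(jax.device_count()))
print(out)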

View File

@@ -528,8 +528,6 @@ jax_test(
jax_test(
name = "pmap_test",
srcs = ["pmap_test.py"],
# pmap already has array tests inside it.
disable_configs = ["cpu_jax_array"],
shard_count = {
"cpu": 15,
"gpu": 30,

View File

@@ -2707,6 +2707,11 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
def testNoCopyIndexing1D(self):
# TODO(https://github.com/google/jax/issues/12016): Implement no copy
# indexing similar to SDA.
if config.jax_array:
self.skipTest('No copy indexing is not implemented for Array yet.')
shape = (8, 4)
if jax.device_count() < shape[0]:
@@ -2798,16 +2803,24 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
def test_repr(self):
x = jax.device_put_replicated(1, jax.devices())
self.assertStartsWith(repr(x), 'ShardedDeviceArray')
if config.jax_array:
arr = 'Array'
else:
arr = 'ShardedDeviceArray'
self.assertStartsWith(repr(x), arr)
def test_delete_is_idempotent(self):
x = jax.device_put_replicated(1, jax.devices())
x.delete()
x.delete()
with self.assertRaisesRegex(ValueError,
'ShardedDeviceArray has been deleted.'):
_ = x[0]
if config.jax_array:
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
_ = x[0]
else:
with self.assertRaisesRegex(ValueError,
'ShardedDeviceArray has been deleted.'):
_ = x[0]
class SpecToIndicesTest(jtu.JaxTestCase):
@@ -3075,11 +3088,14 @@ class ArrayPmapTest(jtu.JaxTestCase):
f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError,
("Array and pmap sharding does not match. Got pmap sharding: 0, "
"Array sharding: None")):
f(a1)
out_array = f(a1)
with jax._src.config.jax_array(False):
out_sda = f(a1)
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
self.assertArraysEqual(out_array.sharding.devices,
[d.device() for d in out_sda.device_buffers])
def test_pmap_array_devices_mismatch(self):
if jax.device_count() <= 1:
@@ -3090,52 +3106,16 @@ class ArrayPmapTest(jtu.JaxTestCase):
f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError, "Devices passed to pmap and Array should be equal."):
f(a1)
out_array = f(a1)
def test_pmap_array_devices_mismatch_between_arrays(self):
if jax.device_count() <= 1:
raise unittest.SkipTest('Skipping because this test needs more than '
'1 device.')
input_shape = (jax.device_count(), 2)
a1, _ = create_input_array_for_pmap(input_shape)
a2, _ = create_input_array_for_pmap(input_shape, devices=jax.devices()[::-1])
with jax._src.config.jax_array(False):
out_sda = f(a1)
f = jax.pmap(lambda x, y: (x, y))
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError, "Devices of all `Array` inputs should be the same."):
f(a1, a2)
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
self.assertArraysEqual(out_array.sharding.devices,
[d.device() for d in out_sda.device_buffers])
class ArrayPmapMixin:
def setUp(self):
super().setUp()
self.array_enabled = config.jax_array
config.update('jax_array', True)
def tearDown(self):
config.update('jax_array', self.array_enabled)
super().tearDown()
class ArrayPythonPmapTest(ArrayPmapMixin, PythonPmapTest):
pass
class ArrayCppPmapTest(ArrayPmapMixin, CppPmapTest):
pass
class ArrayVmapOfPmapTest(ArrayPmapMixin, VmapOfPmapTest):
pass
class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
pass
class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
pass
class EagerPmapMixin:
def setUp(self):