mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Make eager pmap tests pass with Array
. Also add a slow path for Array in pmap
similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
PiperOrigin-RevId: 468843310
This commit is contained in:
parent
3a2f25ff31
commit
f905d989c1
@ -1904,42 +1904,6 @@ class PmapCallInfo(NamedTuple):
|
||||
devices: Optional[Sequence[xc.Device]]
|
||||
|
||||
|
||||
def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
|
||||
from jax.experimental.sharding import PmapSharding, SingleDeviceSharding
|
||||
from jax.experimental.array import Array
|
||||
|
||||
if not args:
|
||||
return
|
||||
|
||||
first_device_assignment = None
|
||||
for a, i in safe_zip(args, in_axes_flat):
|
||||
if not isinstance(a, Array):
|
||||
continue
|
||||
if isinstance(a.sharding, SingleDeviceSharding):
|
||||
continue
|
||||
if not isinstance(a.sharding, PmapSharding):
|
||||
raise NotImplementedError('pmap only works with PmapSharding.')
|
||||
if first_device_assignment is None:
|
||||
first_device_assignment = a.sharding._device_assignment
|
||||
arr_sharding = a.sharding.sharded_dim
|
||||
arr_device_assignment = a.sharding._device_assignment
|
||||
if arr_sharding != i:
|
||||
raise ValueError('Array and pmap sharding does not match. Got pmap '
|
||||
f'sharding: {i}, Array sharding: {arr_sharding} for '
|
||||
f'arg: {a}')
|
||||
if (in_devices is not None and
|
||||
arr_device_assignment is not None and
|
||||
arr_device_assignment != in_devices):
|
||||
raise ValueError('Devices passed to pmap and Array should be equal. '
|
||||
f'Got pmap devices: {in_devices}, Array devices: '
|
||||
f'{arr_device_assignment} for arg: {a}')
|
||||
if (in_devices is None and
|
||||
arr_device_assignment != first_device_assignment):
|
||||
raise ValueError('Devices of all `Array` inputs should be the same. '
|
||||
f'Got array device: {arr_device_assignment}, '
|
||||
f'another array device: {first_device_assignment}')
|
||||
|
||||
|
||||
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
|
||||
donate_tuple, global_arg_shapes, in_devices, args, kwargs):
|
||||
f = lu.wrap_init(fun)
|
||||
@ -1982,9 +1946,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
|
||||
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
|
||||
if config.jax_array:
|
||||
_check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)
|
||||
|
||||
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
|
||||
raise NotImplementedError("None out_axes in pmap are not supported yet")
|
||||
# NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
|
||||
|
@ -120,7 +120,7 @@ class Array:
|
||||
if config.jax_enable_checks:
|
||||
assert all(db.dtype == self.dtype for db in self._arrays), (
|
||||
"Input arrays to `Array` must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in self._arrays]}")
|
||||
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")
|
||||
|
||||
# Rearrange arrays based on the device assignment.
|
||||
if isinstance(sharding, XLACompatibleSharding):
|
||||
@ -378,16 +378,23 @@ def _device_put_array(x, device: Optional[Device]):
|
||||
dispatch.device_put_handlers[Array] = _device_put_array
|
||||
|
||||
|
||||
def _array_shard_arg(x, devices, indices, mode):
|
||||
# TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
|
||||
# code paths.
|
||||
if mode == pxla.InputsHandlerMode.pmap:
|
||||
# sharding mismatch between `Array` and pmap sharding is checked in api.py's
|
||||
# `_check_in_pmap_sharding_with_arrays` function.
|
||||
if isinstance(x.sharding, SingleDeviceSharding):
|
||||
return pxla._shard_device_array(x, devices, indices, mode)
|
||||
def _array_pmap_shard_arg(x, devices, indices, mode):
|
||||
if isinstance(x.sharding, SingleDeviceSharding):
|
||||
return pxla._shard_device_array(x, devices, indices, mode)
|
||||
|
||||
# If the sharding of Array does not match pmap's sharding then take the slow
|
||||
# path which is similar to what SDA does. This slow path reroute only happens
|
||||
# for `pmap`.
|
||||
if indices == tuple(x.sharding.devices_indices_map(x.shape).values()):
|
||||
return [buf if buf.device() == d else buf.copy_to_device(d)
|
||||
for buf, d in safe_zip(x._arrays, devices)]
|
||||
else:
|
||||
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
|
||||
|
||||
|
||||
def _array_shard_arg(x, devices, indices, mode):
|
||||
if mode == pxla.InputsHandlerMode.pmap:
|
||||
return _array_pmap_shard_arg(x, devices, indices, mode)
|
||||
else:
|
||||
return x._arrays
|
||||
pxla.shard_arg_handlers[Array] = _array_shard_arg
|
||||
|
@ -258,13 +258,6 @@ class PmapSharding(XLACompatibleSharding):
|
||||
def device_set(self) -> Set[Device]:
|
||||
return set(self.devices.flat)
|
||||
|
||||
@pxla.maybe_cached_property
|
||||
def sharded_dim(self):
|
||||
for i, s in enumerate(self.sharding_spec.sharding):
|
||||
if isinstance(s, pxla.Unstacked):
|
||||
return i
|
||||
return None
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(
|
||||
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
|
||||
|
@ -853,8 +853,16 @@ def _hashable_index(idx):
|
||||
# The fast path is handled directly in shard_args().
|
||||
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
|
||||
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
|
||||
from jax.experimental.array import Array
|
||||
|
||||
candidates = defaultdict(list)
|
||||
for buf, idx in safe_zip(x.device_buffers, x.indices):
|
||||
if isinstance(x, Array):
|
||||
bufs = x._arrays
|
||||
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
|
||||
else:
|
||||
bufs = x.device_buffers
|
||||
arr_indices = x.indices
|
||||
for buf, idx in safe_zip(bufs, arr_indices):
|
||||
candidates[_hashable_index(idx)].append(buf)
|
||||
|
||||
bufs = []
|
||||
@ -977,10 +985,13 @@ def _emap_impl(fun: lu.WrappedFun, *args,
|
||||
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
|
||||
# We don't want to donate if it's already sharded.
|
||||
donate_argnums_ = ()
|
||||
out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
|
||||
out_axes=out_axis, devices=devices, backend=backend,
|
||||
donate_argnums=donate_argnums_)(
|
||||
np.arange(axis_size), outval)
|
||||
out = jax.pmap(
|
||||
lambda _, x: x,
|
||||
in_axes=(0, out_axis_src.get(axis_name)),
|
||||
out_axes=out_axis,
|
||||
devices=(None if devices is None else list(devices)),
|
||||
backend=backend,
|
||||
donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
|
||||
new_outvals.append(out)
|
||||
return new_outvals
|
||||
|
||||
@ -1000,8 +1011,13 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
|
||||
for i, name in reversed(list(enumerate(names))):
|
||||
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
|
||||
if any(in_axis is not None for in_axis in in_axes):
|
||||
f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
|
||||
backend=info.backend, devices=info.devices)
|
||||
f = jax.pmap(
|
||||
f,
|
||||
in_axes=in_axes,
|
||||
axis_name=name,
|
||||
out_axes=0,
|
||||
backend=info.backend,
|
||||
devices=(None if info.devices is None else list(info.devices)))
|
||||
used_names.append(name)
|
||||
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
|
||||
return f, out_shard_axes
|
||||
|
@ -528,8 +528,6 @@ jax_test(
|
||||
jax_test(
|
||||
name = "pmap_test",
|
||||
srcs = ["pmap_test.py"],
|
||||
# pmap already has array tests inside it.
|
||||
disable_configs = ["cpu_jax_array"],
|
||||
shard_count = {
|
||||
"cpu": 15,
|
||||
"gpu": 30,
|
||||
|
@ -2707,6 +2707,11 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
|
||||
def testNoCopyIndexing1D(self):
|
||||
# TODO(https://github.com/google/jax/issues/12016): Implement no copy
|
||||
# indexing similar to SDA.
|
||||
if config.jax_array:
|
||||
self.skipTest('No copy indexing is not implemented for Array yet.')
|
||||
|
||||
shape = (8, 4)
|
||||
|
||||
if jax.device_count() < shape[0]:
|
||||
@ -2798,16 +2803,24 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_repr(self):
|
||||
x = jax.device_put_replicated(1, jax.devices())
|
||||
self.assertStartsWith(repr(x), 'ShardedDeviceArray')
|
||||
if config.jax_array:
|
||||
arr = 'Array'
|
||||
else:
|
||||
arr = 'ShardedDeviceArray'
|
||||
self.assertStartsWith(repr(x), arr)
|
||||
|
||||
def test_delete_is_idempotent(self):
|
||||
x = jax.device_put_replicated(1, jax.devices())
|
||||
x.delete()
|
||||
x.delete()
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'ShardedDeviceArray has been deleted.'):
|
||||
_ = x[0]
|
||||
if config.jax_array:
|
||||
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
||||
_ = x[0]
|
||||
else:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'ShardedDeviceArray has been deleted.'):
|
||||
_ = x[0]
|
||||
|
||||
|
||||
class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
@ -3075,11 +3088,14 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
|
||||
f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
|
||||
with jax._src.config.jax_array(True):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
("Array and pmap sharding does not match. Got pmap sharding: 0, "
|
||||
"Array sharding: None")):
|
||||
f(a1)
|
||||
out_array = f(a1)
|
||||
|
||||
with jax._src.config.jax_array(False):
|
||||
out_sda = f(a1)
|
||||
|
||||
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
|
||||
self.assertArraysEqual(out_array.sharding.devices,
|
||||
[d.device() for d in out_sda.device_buffers])
|
||||
|
||||
def test_pmap_array_devices_mismatch(self):
|
||||
if jax.device_count() <= 1:
|
||||
@ -3090,52 +3106,16 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
|
||||
f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
|
||||
with jax._src.config.jax_array(True):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Devices passed to pmap and Array should be equal."):
|
||||
f(a1)
|
||||
out_array = f(a1)
|
||||
|
||||
def test_pmap_array_devices_mismatch_between_arrays(self):
|
||||
if jax.device_count() <= 1:
|
||||
raise unittest.SkipTest('Skipping because this test needs more than '
|
||||
'1 device.')
|
||||
input_shape = (jax.device_count(), 2)
|
||||
a1, _ = create_input_array_for_pmap(input_shape)
|
||||
a2, _ = create_input_array_for_pmap(input_shape, devices=jax.devices()[::-1])
|
||||
with jax._src.config.jax_array(False):
|
||||
out_sda = f(a1)
|
||||
|
||||
f = jax.pmap(lambda x, y: (x, y))
|
||||
with jax._src.config.jax_array(True):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Devices of all `Array` inputs should be the same."):
|
||||
f(a1, a2)
|
||||
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
|
||||
self.assertArraysEqual(out_array.sharding.devices,
|
||||
[d.device() for d in out_sda.device_buffers])
|
||||
|
||||
|
||||
class ArrayPmapMixin:
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.array_enabled = config.jax_array
|
||||
config.update('jax_array', True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update('jax_array', self.array_enabled)
|
||||
super().tearDown()
|
||||
|
||||
|
||||
class ArrayPythonPmapTest(ArrayPmapMixin, PythonPmapTest):
|
||||
pass
|
||||
|
||||
class ArrayCppPmapTest(ArrayPmapMixin, CppPmapTest):
|
||||
pass
|
||||
|
||||
class ArrayVmapOfPmapTest(ArrayPmapMixin, VmapOfPmapTest):
|
||||
pass
|
||||
|
||||
class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
|
||||
pass
|
||||
|
||||
class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
|
||||
pass
|
||||
|
||||
class EagerPmapMixin:
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user