Merge pull request #12705 from mattjj:fix-prng-key-array-device-put

PiperOrigin-RevId: 479813689
jax authors 2022-10-08 11:39:05 -07:00
commit 674038ca47
6 changed files with 74 additions and 45 deletions


@@ -545,38 +545,25 @@ def _array_pmap_shard_arg(x, devices, indices, mode):
   return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
 
 
-def _array_rest_shard_arg(x, devices, indices, mode):
-  if not x._committed:
+def _array_rest_shard_arg(x: ArrayImpl, devices, indices, mode):
+  x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
+  if not x.is_fully_addressable():
+    if tuple(x_indices) == tuple(indices):
+      return x._arrays
+    else:
+      return NotImplementedError("Cannot reshard an input that is not fully "
+                                 "addressable")
+  else:
+    if tuple(x_indices) == tuple(indices):
+      return [buf if buf.device() == d else buf.copy_to_device(d)
+              for buf, d in safe_zip(x._arrays, devices)]
+    # Resharding starts here:
     if isinstance(x.sharding, PmapSharding):
       return pxla.device_put(x._value, devices, replicate=True)
     if dispatch.is_single_device_sharding(x.sharding):
-      # This condition is to break the recursion that happens when only
-      # `pxla._shard_device_array` is used since it has `_multi_slice` in the
-      # implementation which is jitted. Eventually it calls back here and the
-      # recursion happens.
-      x_indices = tuple(x.sharding.addressable_devices_indices_map(x.shape).values())
-      if x_indices == indices:
-        return [buf if buf.device() == d else buf.copy_to_device(d)
-                for buf, d in safe_zip(x._arrays, devices)]
       return pxla._shard_device_array(x, devices, indices, mode)
     else:
-      raise NotImplementedError('Resharding uncommitted arrays sharded over '
-                                'multiple devices is not supported.')
-  # TODO(yashkatariya): Remove the special case here and don't move to another
-  # device if its already committed. There is a TODO in dispatch.py already
-  # for this.
-  if dispatch.is_single_device_sharding(x.sharding):
-    return [buf if buf.device() == d else buf.copy_to_device(d)
-            for buf, d in safe_zip(x._arrays, devices)]
-  # If PmapSharding exists, then do a round trip via host. This will happen
-  # if the input Array containing PmapSharding takes the jit path
-  # i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
-  # common case where this will happen.
-  # TODO(yashkatariya): Remove the special case here and don't move to another
-  # device if its already committed. There is a TODO in dispatch.py already
-  # for this.
-  elif isinstance(x.sharding, PmapSharding):
-    return pxla.device_put(x._value, devices, replicate=True)
-  else:
-    return x._arrays
+      return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
 
 
 def _array_shard_arg(x, devices, indices, mode):
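
The rewritten handler boils down to a three-way decision: reuse the existing per-device buffers when the array's own index map already matches the requested indices, refuse to reshard an array that is not fully addressable, and otherwise fall back to a resharding path (the PmapSharding host round trip, _shard_device_array, or the generic slow path). A minimal standalone sketch of that decision, using plain-Python stand-ins rather than JAX internals (the names below are hypothetical):

from typing import List, Tuple

Index = Tuple[slice, ...]

def shard_arg_sketch(buffers: List[str],
                     current_indices: List[Index],
                     wanted_indices: List[Index],
                     fully_addressable: bool) -> List[str]:
  # Fast path: the array is already laid out exactly as requested, so its
  # existing per-device buffers can be reused as-is.
  if tuple(current_indices) == tuple(wanted_indices):
    return buffers
  # Shards living on devices this process cannot see cannot be resharded here.
  if not fully_addressable:
    raise NotImplementedError(
        "Cannot reshard an input that is not fully addressable")
  # Otherwise reshard explicitly (in JAX: host round trip for PmapSharding,
  # _shard_device_array, or the generic slow path, depending on the sharding).
  return ["resharded:" + b for b in buffers]

# Matching index maps reuse the buffers; differing ones trigger a reshard.
idx_a = [(slice(0, 4),), (slice(4, 8),)]
idx_b = [(slice(0, 8),), (slice(0, 8),)]
print(shard_arg_sketch(["buf0", "buf1"], idx_a, idx_a, True))
print(shard_arg_sketch(["buf0", "buf1"], idx_b, idx_a, True))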


@@ -1295,16 +1295,28 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
 def _device_put_impl(
     x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
   from jax._src import array, sharding
+  from jax.interpreters import pxla
+
+  try:
+    a = xla.abstractify(x)
+  except TypeError as err:
+    raise TypeError(
+        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
 
   if isinstance(device, sharding.Sharding):
-    if not device.is_fully_addressable():  # type: ignore
+    s = device
+    if not s.is_fully_addressable():  # type: ignore
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which "
           f"represents addressable devices, but got {sharding}")
-    if getattr(x, 'sharding', None) == device:
+    if getattr(x, 'sharding', None) == s:
       return x
-    # TODO(mattjj,yashkatariya,phawkins): runtime fast resharding here?
-    return array.make_array_from_callback(x.shape, device, lambda idx: x[idx])
+    # TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
+    arg_handler = pxla.shard_arg_handlers[type(x)]
+    result_handler = pxla.global_aval_to_result_handler(a, s, True, False)
+    map_ = s.devices_indices_map(x.shape)  # type: ignore
+    return result_handler(arg_handler(x, list(map_), list(map_.values()),
+                                      pxla.InputsHandlerMode.pjit_or_xmap))
 
   # Only `Device` exists below. `Sharding` instance is handled above.
   if isinstance(x, array.ArrayImpl):
@@ -1320,11 +1332,6 @@ def _device_put_impl(
   if device_array.type_is_device_array(x):
     return _copy_device_array_to_device(x, device)
 
-  try:
-    a = xla.abstractify(x)
-  except TypeError as err:
-    raise TypeError(
-        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
   return aval_to_result_handler(device, a)(None, *device_put(x, device))
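
With this change, jax.device_put(x, some_sharding) shards x through the shard_arg handler registered for its type and reassembles the output with global_aval_to_result_handler, which is what lets PRNG key arrays (opaque dtypes) take the same path as ordinary arrays. A hedged user-level sketch mirroring the new tests, written against the later public names (MeshPspecSharding was subsequently renamed NamedSharding) and assuming the local devices divide the array's leading dimension:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))
s = NamedSharding(mesh, P('x'))

# Ordinary array: device_put lays it out according to the sharding.
x = jnp.arange(len(jax.devices()))
y = jax.device_put(x, s)
assert y.sharding == s

# PRNG key array: previously this path failed; after this commit the keys are
# sharded across the mesh just like any other array.
keys = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
sharded_keys = jax.device_put(keys, s)
assert sharded_keys.sharding == s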


@@ -875,9 +875,10 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
       if committed and not pxla.are_op_shardings_equal(
           pjit_in_s._to_xla_op_sharding(arg.ndim),
           arg_s._to_xla_op_sharding(arg.ndim)):
+        op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
         raise ValueError('Sharding passed to pjit does not match the sharding '
                          'on the respective arg. '
-                         f'Got pjit sharding: {pjit_in_s},\n'
+                         f'Got pjit sharding: {op},\n'
                          f'arg sharding: {arg_s}')
     resolved_in_shardings.append(pjit_in_s)
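
This error fires when a committed argument already carries a sharding that disagrees with the one requested for it in pjit; the change only makes the message report the user-facing sharding, preferring an _original_sharding attribute when the internal object carries one. A tiny standalone sketch of that reporting pattern, with hypothetical stand-in classes rather than JAX's own:

class InternalSharding:
  # Stand-in for an internal sharding object that remembers what the user wrote.
  def __init__(self, original):
    self._original_sharding = original
  def __repr__(self):
    return 'InternalSharding(...)'

class UserSharding:
  def __repr__(self):
    return "NamedSharding(mesh, PartitionSpec('x'))"

def mismatch_error(pjit_in_s, arg_s):
  # Prefer the sharding the user actually passed so the error reads in their terms.
  op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
  raise ValueError('Sharding passed to pjit does not match the sharding '
                   'on the respective arg. '
                   f'Got pjit sharding: {op},\n'
                   f'arg sharding: {arg_s}')

try:
  mismatch_error(InternalSharding(UserSharding()), UserSharding())
except ValueError as e:
  print(e)  # mentions NamedSharding(...), not InternalSharding(...)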


@@ -592,7 +592,7 @@ local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_arra
 
 def global_aval_to_result_handler(
     aval: core.AbstractValue, out_sharding, committed: bool,
     is_out_sharding_from_xla: bool
-) -> Callable[[List[xb.xla_client.Buffer]], Any]:
+) -> Callable[[Sequence[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
 
   Args:
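
Widening the handler's declared argument type from List to Sequence matters because List[...] only matches an actual list under a type checker, so passing a tuple of buffers is flagged even though the handler only reads from them; Sequence[...] accepts any read-only ordered collection. A small illustration in plain Python typing (unrelated to JAX internals):

from typing import List, Sequence

def takes_list(bufs: List[int]) -> int:
  return sum(bufs)

def takes_seq(bufs: Sequence[int]) -> int:
  return sum(bufs)

bufs = (1, 2, 3)   # result buffers often arrive as a tuple
takes_seq(bufs)    # fine: a tuple is a Sequence
takes_list(bufs)   # runs, but mypy/pytype flag it: Tuple[int, ...] is not List[int]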


@@ -1487,10 +1487,12 @@ class APITest(jtu.JaxTestCase):
     self.assertIsInstance(y2[1][1], np.ndarray)
     assert np.all(y2[1][1] == 3 * x)
 
+  @jax_config.jax_array(True)
   def test_device_put_sharding(self):
     mesh = maps.Mesh(jax.devices(), ('x',))
     s = sharding.MeshPspecSharding(mesh, P('x'))
     x = jnp.arange(len(jax.devices()))
     y = jax.device_put(x, s)
     self.assertEqual(y.sharding, s)
     self.assertArraysAllClose(y, x)
@@ -1508,11 +1510,6 @@ class APITest(jtu.JaxTestCase):
     self.assertArraysAllClose(u, y)
     self.assertEqual(u.device(), jax.devices()[0])
 
-    # TODO(frostig): make this pass with JAX_ENABLE_CUSTOM_PRNG=1
-    # # this can cover opaque dtypes
-    # x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
-    # jax.device_put(x, s)  # doesn't crash
 
   def test_device_get_scalar(self):
     x = np.arange(12.).reshape((3, 4)).astype("float32")
     x = api.device_put(x)


@@ -2197,6 +2197,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
       self.assertArraysEqual(out, arr)
       self.assertLen(out.addressable_shards, 8)
 
+  @jax_array(True)
+  def test_pjit_uncommitted_array_in_axis_resources_reshard(self):
+    arr = jnp.arange(16).reshape(8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    with mesh:
+      out = pjit(lambda x: x, in_axis_resources=P('x', 'y'))(arr)
+      self.assertArraysEqual(out, arr)
+      self.assertLen(out.addressable_shards, 8)
+      for s in out.addressable_shards:
+        self.assertArraysEqual(s.data, arr[s.index])
+        self.assertEqual(s.replica_id, 0)
+
   @jax_array(True)
   def test_pjit_uncommitted_array_and_committed_array(self):
     shape = (8, 2)
@@ -2291,6 +2303,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
     cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
+  @jax_array(True)
+  def test_device_put_sharding_prng(self):
+    mesh = jtu.create_global_mesh((8,), ('x',))
+    s = MeshPspecSharding(mesh, P('x'))
+    x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
+    y = jax.device_put(x, s)
+    if config.jax_enable_custom_prng:
+      self.assertIsInstance(y, jax.random.KeyArray)
+    self.assertEqual(y.sharding, s)
+
+  @jax_array(True)
+  def test_device_put_on_different_sharding(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+
+    x = jnp.arange(8).reshape(4, 2)
+    s1 = MeshPspecSharding(mesh, P('x'))
+    a = jax.device_put(x, s1)
+    self.assertEqual(a.sharding, s1)
+
+    s2 = MeshPspecSharding(mesh, P('x', 'y'))
+    b = jax.device_put(a, s2)
+    self.assertEqual(b.sharding, s2)
+
 
 class ArrayCppPjitTest(ArrayPjitTest):