Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
This commit is contained in:
commit 674038ca47
@@ -545,38 +545,25 @@ def _array_pmap_shard_arg(x, devices, indices, mode):
     return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
 
 
-def _array_rest_shard_arg(x, devices, indices, mode):
-  if not x._committed:
+def _array_rest_shard_arg(x: ArrayImpl, devices, indices, mode):
+  x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
+  if not x.is_fully_addressable():
+    if tuple(x_indices) == tuple(indices):
+      return x._arrays
+    else:
+      return NotImplementedError("Cannot reshard an input that is not fully "
+                                 "addressable")
+  else:
+    if tuple(x_indices) == tuple(indices):
+      return [buf if buf.device() == d else buf.copy_to_device(d)
+              for buf, d in safe_zip(x._arrays, devices)]
+    # Resharding starts here:
     if isinstance(x.sharding, PmapSharding):
       return pxla.device_put(x._value, devices, replicate=True)
     if dispatch.is_single_device_sharding(x.sharding):
-      # This condition is to break the recursion that happens when only
-      # `pxla._shard_device_array` is used since it has `_multi_slice` in the
-      # implementation which is jitted. Eventually it calls back here and the
-      # recursion happens.
-      x_indices = tuple(x.sharding.addressable_devices_indices_map(x.shape).values())
-      if x_indices == indices:
-        return [buf if buf.device() == d else buf.copy_to_device(d)
-                for buf, d in safe_zip(x._arrays, devices)]
       return pxla._shard_device_array(x, devices, indices, mode)
-    else:
-      raise NotImplementedError('Resharding uncommitted arrays sharded over '
-                                'multiple devices is not supported.')
-  # TODO(yashkatariya): Remove the special case here and don't move to another
-  # device if its already committed. There is a TODO in dispatch.py already
-  # for this.
-  if dispatch.is_single_device_sharding(x.sharding):
-    return [buf if buf.device() == d else buf.copy_to_device(d)
-            for buf, d in safe_zip(x._arrays, devices)]
-  # If PmapSharding exists, then do a round trip via host. This will happen
-  # if the input Array containing PmapSharding takes the jit path
-  # i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
-  # common case where this will happen.
-  # TODO(yashkatariya): Remove the special case here and don't move to another
-  # device if its already committed. There is a TODO in dispatch.py already
-  # for this.
-  elif isinstance(x.sharding, PmapSharding):
-    return pxla.device_put(x._value, devices, replicate=True)
-  else:
-    return x._arrays
     return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
 
 
 def _array_shard_arg(x, devices, indices, mode):
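In rough terms, the rework above replaces branching on `x._committed` with branching on addressability: a fully addressable `jax.Array` whose per-device indices already match the requested ones is passed through or copied buffer by buffer, and only genuine resharding falls through to the PmapSharding, single-device, or slow paths. A minimal sketch of code that exercises these handlers (a sketch under assumptions: jax of roughly this era, ~0.3.21, a multi-device runtime, and the convention that `_array_pmap_shard_arg` serves pmap while `_array_rest_shard_arg` serves the other execution paths):

import jax
import jax.numpy as jnp

jax.config.update('jax_array', True)   # route inputs through the jax.Array shard_arg handlers

n = jax.local_device_count()
x = jnp.arange(n)                      # uncommitted, single-device jax.Array
y = jax.pmap(lambda v: v + 1)(x)       # sharding x across devices goes through these handlers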
@@ -1295,16 +1295,28 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
 def _device_put_impl(
     x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
   from jax._src import array, sharding
   from jax.interpreters import pxla
 
+  try:
+    a = xla.abstractify(x)
+  except TypeError as err:
+    raise TypeError(
+        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
+
   if isinstance(device, sharding.Sharding):
-    if not device.is_fully_addressable():  # type: ignore
+    s = device
+    if not s.is_fully_addressable():  # type: ignore
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which "
           f"represents addressable devices, but got {sharding}")
-    if getattr(x, 'sharding', None) == device:
+    if getattr(x, 'sharding', None) == s:
       return x
-    # TODO(mattjj,yashkatariya,phawkins): runtime fast resharding here?
-    return array.make_array_from_callback(x.shape, device, lambda idx: x[idx])
+    # TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
+    arg_handler = pxla.shard_arg_handlers[type(x)]
+    result_handler = pxla.global_aval_to_result_handler(a, s, True, False)
+    map_ = s.devices_indices_map(x.shape)  # type: ignore
+    return result_handler(arg_handler(x, list(map_), list(map_.values()),
+                                      pxla.InputsHandlerMode.pjit_or_xmap))
 
+  # Only `Device` exists below. `Sharding` instance is handled above.
   if isinstance(x, array.ArrayImpl):
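The hunk above changes what `jax.device_put(x, s)` does when `s` is a `Sharding`: instead of a host round trip through `array.make_array_from_callback`, it shards `x` with the registered `shard_arg_handlers` and rebuilds the result with `global_aval_to_result_handler`, which is what lets opaque-dtype values such as PRNG key arrays take this path. A usage sketch (the import paths are era-specific assumptions; adjust for your jax version):

import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.sharding import MeshPspecSharding

mesh = maps.Mesh(jax.devices(), ('x',))   # 1-D mesh over all devices
s = MeshPspecSharding(mesh, P('x'))       # shard the leading axis over 'x'
x = jnp.arange(len(jax.devices()))

y = jax.device_put(x, s)                  # places one shard per device
assert y.sharding == s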
@@ -1320,11 +1332,6 @@ def _device_put_impl(
   if device_array.type_is_device_array(x):
     return _copy_device_array_to_device(x, device)
 
-  try:
-    a = xla.abstractify(x)
-  except TypeError as err:
-    raise TypeError(
-        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
   return aval_to_result_handler(device, a)(None, *device_put(x, device))
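Together with the previous hunk, this is a small hoist: `xla.abstractify(x)` now runs once at the top of `_device_put_impl`, so the aval `a` is available both to the new `Sharding` branch (for `global_aval_to_result_handler`) and to the existing `Device` branch (for `aval_to_result_handler`). In outline (the helper names here are hypothetical, purely to show the shape of the refactor):

def _device_put_impl_outline(x, device=None):
    a = abstractify(x)                # once, up front; raises TypeError for non-JAX types
    if isinstance(device, Sharding):
        return put_with_sharding(a, x, device)   # new path from the hunk above
    return put_on_device(a, x, device)           # pre-existing single-device path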
@@ -875,9 +875,10 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
     if committed and not pxla.are_op_shardings_equal(
         pjit_in_s._to_xla_op_sharding(arg.ndim),
         arg_s._to_xla_op_sharding(arg.ndim)):
+      op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
       raise ValueError('Sharding passed to pjit does not match the sharding '
                        'on the respective arg. '
-                       f'Got pjit sharding: {pjit_in_s},\n'
+                       f'Got pjit sharding: {op},\n'
                        f'arg sharding: {arg_s}')
     resolved_in_shardings.append(pjit_in_s)
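The added `op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)` makes the error message name the sharding the user actually passed rather than its lowered internal form. A hypothetical repro of this error path (the setup mirrors the tests in this commit; `create_global_mesh`, `MeshPspecSharding`, and the pjit import paths are era-specific assumptions):

import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental.sharding import MeshPspecSharding
from jax._src import test_util as jtu

mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
x = jnp.arange(16).reshape(8, 2)
a = jax.device_put(x, MeshPspecSharding(mesh, P('x')))   # committed with spec P('x')
with mesh:
  pjit(lambda v: v, in_axis_resources=P('x', 'y'))(a)    # raises: Sharding passed to
                                                         # pjit does not match the
                                                         # sharding on the respective arg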
@@ -592,7 +592,7 @@ local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
 def global_aval_to_result_handler(
     aval: core.AbstractValue, out_sharding, committed: bool,
     is_out_sharding_from_xla: bool
-) -> Callable[[List[xb.xla_client.Buffer]], Any]:
+) -> Callable[[Sequence[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
 
   Args:
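Widening the annotation from `List` to `Sequence` matters because `List` is invariant and mutable in the eyes of a type checker, while `Sequence` is covariant and read-only, so result handlers can now be fed tuples of buffers as well as lists. A toy illustration of the difference (plain `typing`, independent of jax):

from typing import List, Sequence

def handle(bufs: Sequence[int]) -> int:   # read-only: any sequence type is acceptable
    return sum(bufs)

handle([1, 2, 3])   # ok
handle((1, 2, 3))   # ok -- a List[int] parameter would be flagged by mypy here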
@@ -1487,10 +1487,12 @@ class APITest(jtu.JaxTestCase):
     self.assertIsInstance(y2[1][1], np.ndarray)
     assert np.all(y2[1][1] == 3 * x)
 
+  @jax_config.jax_array(True)
   def test_device_put_sharding(self):
     mesh = maps.Mesh(jax.devices(), ('x',))
     s = sharding.MeshPspecSharding(mesh, P('x'))
     x = jnp.arange(len(jax.devices()))
 
     y = jax.device_put(x, s)
     self.assertEqual(y.sharding, s)
     self.assertArraysAllClose(y, x)
@@ -1508,11 +1510,6 @@
     self.assertArraysAllClose(u, y)
     self.assertEqual(u.device(), jax.devices()[0])
 
-    # TODO(frostig): make this pass with JAX_ENABLE_CUSTOM_PRNG=1
-    # # this can cover opaque dtypes
-    # x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
-    # jax.device_put(x, s)  # doesn't crash
-
   def test_device_get_scalar(self):
     x = np.arange(12.).reshape((3, 4)).astype("float32")
     x = api.device_put(x)
@@ -2197,6 +2197,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, arr)
     self.assertLen(out.addressable_shards, 8)
 
+  @jax_array(True)
+  def test_pjit_uncommitted_array_in_axis_resources_reshard(self):
+    arr = jnp.arange(16).reshape(8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    with mesh:
+      out = pjit(lambda x: x, in_axis_resources=P('x', 'y'))(arr)
+      self.assertArraysEqual(out, arr)
+      self.assertLen(out.addressable_shards, 8)
+      for s in out.addressable_shards:
+        self.assertArraysEqual(s.data, arr[s.index])
+        self.assertEqual(s.replica_id, 0)
+
   @jax_array(True)
   def test_pjit_uncommitted_array_and_committed_array(self):
     shape = (8, 2)
@@ -2291,6 +2303,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
     cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
+  @jax_array(True)
+  def test_device_put_sharding_prng(self):
+    mesh = jtu.create_global_mesh((8,), ('x',))
+    s = MeshPspecSharding(mesh, P('x'))
+
+    x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
+    y = jax.device_put(x, s)
+
+    if config.jax_enable_custom_prng:
+      self.assertIsInstance(y, jax.random.KeyArray)
+    self.assertEqual(y.sharding, s)
+
+  @jax_array(True)
+  def test_device_put_on_different_sharding(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+
+    x = jnp.arange(8).reshape(4, 2)
+    s1 = MeshPspecSharding(mesh, P('x'))
+    a = jax.device_put(x, s1)
+    self.assertEqual(a.sharding, s1)
+
+    s2 = MeshPspecSharding(mesh, P('x', 'y'))
+    b = jax.device_put(a, s2)
+    self.assertEqual(b.sharding, s2)
+
 
 class ArrayCppPjitTest(ArrayPjitTest):
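`test_device_put_sharding_prng` exercises the point of the PR title: with the new `Sharding` branch in `_device_put_impl`, a batch of PRNG keys, which is an opaque-dtype array under custom PRNG, can be `device_put` to a sharding instead of crashing. In user terms (a sketch assuming jax of this era with `jax_enable_custom_prng` enabled, reusing the era-specific imports from the earlier sketches):

mesh = jtu.create_global_mesh((8,), ('x',))
s = MeshPspecSharding(mesh, P('x'))

keys = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
sharded = jax.device_put(keys, s)   # previously failed for key arrays; now one key per device
assert sharded.sharding == s
assert isinstance(sharded, jax.random.KeyArray)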