From 5a8c12db9fc5bd8be7a963c725323a4b063ac965 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 23 Feb 2023 15:37:13 -0800
Subject: [PATCH] Add a helpful error message when device_putting with a
 Sharding that is incompatible with the shape of the input

PiperOrigin-RevId: 511905019
---
 jax/_src/dispatch.py | 17 ++++++++++++++++-
 jax/_src/sharding.py |  8 ++++++++
 tests/pjit_test.py   | 25 +++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 5e2ff7c6f..2e8c6fee5 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -63,7 +63,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
                                GSPMDSharding, NamedSharding, PartitionSpec,
-                               Sharding)
+                               Sharding, XLACompatibleSharding)
 from jax._src.util import flatten, unflatten
 
 
@@ -1387,6 +1387,18 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
       committed=(device is not None))
 
 
+# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
+# to check if shardings are compatible with the input.
+def _check_sharding(x, s):
+  from jax._src import pjit
+
+  if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
+    pjit.pjit_check_aval_sharding(
+        (s,), (x,), "device_put args", allow_uneven_sharding=False)
+
+  s.shard_shape(x.shape)  # should raise an Error if incompatible
+
+
 def _device_put_impl(
     x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
   from jax.interpreters import pxla
@@ -1409,6 +1421,9 @@
     raise ValueError(
         "device_put's second argument must be a Device or a Sharding which "
         f"represents addressable devices, but got {s}")
+
+  _check_sharding(x, s)
+
   if getattr(x, 'sharding', None) == s:
     return x
   # TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
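
Illustrative sketch, not part of the patch: a minimal reproduction of the eager
failure the new _check_sharding helper produces, using only the public
device_put API. It assumes a runtime with at least two addressable devices
(e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2 on CPU):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # A 1-D mesh over two devices; P('x') splits dimension 0 across the mesh.
    mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
    s = NamedSharding(mesh, P('x'))

    try:
      # Dimension 0 has size 1, which is not divisible by the 2 mesh devices.
      jax.device_put(jnp.ones((1,)), s)
    except ValueError as e:
      print(e)  # raised up front by the new check instead of failing later
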
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index afb90a2d0..e5dc198e5 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -518,12 +518,20 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
     sharded_dim = None
+    sharded_dim_size = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
         sharded_dim = i
+        sharded_dim_size = s.size
         break
     if sharded_dim is None:
       return global_shape
+    if global_shape[sharded_dim] != sharded_dim_size:
+      raise ValueError(
+          f'The sharded dimension must be equal to the number of '
+          f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
+          f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
+          f'the number of devices={len(self._device_assignment)}')
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b9636c996..525cd57b0 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3407,6 +3407,31 @@
         out2.sharding._to_xla_op_sharding(out2.ndim))
     self.assertListEqual(ns2, [2, 2, 1, 1])
 
+  def test_device_put_sharding_nondivisible_sharding_error(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P('x'))
+
+    x = jnp.ones((1,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+                    'divisible by 2, but it is equal to 1 '):
+      jax.device_put(x, s)
+
+    y = jnp.ones((2,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+                    'divisible by 2, but it is equal to 1 '):
+      jax.device_put((y, x), s)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "The sharded dimension must be equal to the number of "
+        "devices passed to PmapSharding. Got sharded dimension 0 with value 1 "
+        r"in shape \(1,\) and the number of devices=2"):
+      s2 = jax.pmap(lambda x: x,
+                    devices=list(mesh.devices.flat))(jnp.arange(2)).sharding
+      jax.device_put(x, s2)
+
 
 class TempSharding(Sharding):
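
Illustrative sketch, not part of the patch: the PmapSharding branch exercised
by the last assertion above, again assuming at least two addressable devices.
The output of pmap carries a PmapSharding whose sharded dimension must match
the device count, so a shape-(1,) input now fails fast in device_put:

    import jax
    import jax.numpy as jnp

    # pmap output is sharded with a PmapSharding over jax.device_count() devices.
    pmap_sharding = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding

    try:
      # A size-1 leading dimension cannot be split across that many devices.
      jax.device_put(jnp.ones((1,)), pmap_sharding)
    except ValueError as e:
      print(e)  # 'The sharded dimension must be equal to the number of devices ...'
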