Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input

PiperOrigin-RevId: 511905019
Yash Katariya 2023-02-23 15:37:13 -08:00 committed by jax authors
parent b5026207bc
commit 5a8c12db9f
3 changed files with 49 additions and 1 deletion
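
For context, a minimal sketch of the behavior this change targets (not part of the commit; it assumes a runtime with at least two devices and mirrors the new test added below):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard dimension 0 of the input across two devices.
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
s = NamedSharding(mesh, P('x'))

x = jnp.ones((1,))  # dimension 0 has size 1, which cannot be split evenly across 2 devices
try:
  jax.device_put(x, s)
except ValueError as e:
  print(e)  # with this change: a clear message that dimension 0 must be divisible by 2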


@@ -63,7 +63,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
                                GSPMDSharding, NamedSharding, PartitionSpec,
-                               Sharding)
+                               Sharding, XLACompatibleSharding)
 from jax._src.util import flatten, unflatten
@@ -1387,6 +1387,18 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
       committed=(device is not None))
 
 
+# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
+# to check if shardings are compatible with the input.
+def _check_sharding(x, s):
+  from jax._src import pjit
+
+  if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
+    pjit.pjit_check_aval_sharding(
+        (s,), (x,), "device_put args", allow_uneven_sharding=False)
+
+  s.shard_shape(x.shape)  # should raise an Error if incompatible
+
+
 def _device_put_impl(
     x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
   from jax.interpreters import pxla
@@ -1409,6 +1421,9 @@ def _device_put_impl(
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which "
          f"represents addressable devices, but got {s}")
+
+    _check_sharding(x, s)
+
     if getattr(x, 'sharding', None) == s:
       return x
     # TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?


@@ -518,12 +518,20 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
     sharded_dim = None
+    sharded_dim_size = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
         sharded_dim = i
+        sharded_dim_size = s.size
         break
     if sharded_dim is None:
       return global_shape
+    if global_shape[sharded_dim] != sharded_dim_size:
+      raise ValueError(
+          f'The sharded dimension must be equal to the number of '
+          f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
+          f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
+          f'the number of devices={len(self._device_assignment)}')
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
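
As an aside (not part of the diff), a hedged sketch of what the new shard_shape check catches, assuming more than one local device:

import jax
import jax.numpy as jnp

n = jax.device_count()
# pmap produces a PmapSharding that unstacks dimension 0 across all n devices.
ps = jax.pmap(lambda x: x)(jnp.arange(n)).sharding
print(ps.shard_shape((n,)))  # OK: dimension 0 matches n, per-device shape is ()
try:
  ps.shard_shape((1,))       # with n > 1, dimension 0 no longer matches -> the new ValueError
except ValueError as e:
  print(e)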


@@ -3407,6 +3407,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
         out2.sharding._to_xla_op_sharding(out2.ndim))
     self.assertListEqual(ns2, [2, 2, 1, 1])
 
+  def test_device_put_sharding_nondivisible_sharding_error(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P('x'))
+
+    x = jnp.ones((1,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+        'divisible by 2, but it is equal to 1 '):
+      jax.device_put(x, s)
+
+    y = jnp.ones((2,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+        'divisible by 2, but it is equal to 1 '):
+      jax.device_put((y, x), s)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "The sharded dimension must be equal to the number of "
+        "devices passed to PmapSharding. Got sharded dimension 0 with value 1 "
+        r"in shape \(1,\) and the number of devices=2"):
+      s2 = jax.pmap(lambda x: x,
+                    devices=list(mesh.devices.flat))(jnp.arange(2)).sharding
+      jax.device_put(x, s2)
+
 
 class TempSharding(Sharding):