mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
PiperOrigin-RevId: 511905019
This commit is contained in:
parent
b5026207bc
commit
5a8c12db9f
@ -63,7 +63,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
|
||||
GSPMDSharding, NamedSharding, PartitionSpec,
|
||||
Sharding)
|
||||
Sharding, XLACompatibleSharding)
|
||||
from jax._src.util import flatten, unflatten
|
||||
|
||||
|
||||
@ -1387,6 +1387,18 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
|
||||
committed=(device is not None))
|
||||
|
||||
|
||||
# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
|
||||
# to check if shardings are compatible with the input.
|
||||
def _check_sharding(x, s):
|
||||
from jax._src import pjit
|
||||
|
||||
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
|
||||
pjit.pjit_check_aval_sharding(
|
||||
(s,), (x,), "device_put args", allow_uneven_sharding=False)
|
||||
|
||||
s.shard_shape(x.shape) # should raise an Error if incompatible
|
||||
|
||||
|
||||
def _device_put_impl(
|
||||
x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
|
||||
from jax.interpreters import pxla
|
||||
@ -1409,6 +1421,9 @@ def _device_put_impl(
|
||||
raise ValueError(
|
||||
"device_put's second argument must be a Device or a Sharding which "
|
||||
f"represents addressable devices, but got {s}")
|
||||
|
||||
_check_sharding(x, s)
|
||||
|
||||
if getattr(x, 'sharding', None) == s:
|
||||
return x
|
||||
# TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
|
||||
|
@ -518,12 +518,20 @@ class PmapSharding(XLACompatibleSharding):
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def shard_shape(self, global_shape: Shape) -> Shape:
|
||||
sharded_dim = None
|
||||
sharded_dim_size = None
|
||||
for i, s in enumerate(self.sharding_spec.sharding):
|
||||
if isinstance(s, pxla.Unstacked):
|
||||
sharded_dim = i
|
||||
sharded_dim_size = s.size
|
||||
break
|
||||
if sharded_dim is None:
|
||||
return global_shape
|
||||
if global_shape[sharded_dim] != sharded_dim_size:
|
||||
raise ValueError(
|
||||
f'The sharded dimension must be equal to the number of '
|
||||
f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
|
||||
f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
|
||||
f'the number of devices={len(self._device_assignment)}')
|
||||
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
|
||||
|
||||
|
||||
|
@ -3407,6 +3407,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out2.sharding._to_xla_op_sharding(out2.ndim))
|
||||
self.assertListEqual(ns2, [2, 2, 1, 1])
|
||||
|
||||
def test_device_put_sharding_nondivisible_sharding_error(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
||||
x = jnp.ones((1,))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'implies that the size of its dimension 0 should be '
|
||||
'divisible by 2, but it is equal to 1 '):
|
||||
jax.device_put(x, s)
|
||||
|
||||
y = jnp.ones((2,))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'implies that the size of its dimension 0 should be '
|
||||
'divisible by 2, but it is equal to 1 '):
|
||||
jax.device_put((y, x), s)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The sharded dimension must be equal to the number of "
|
||||
"devices passed to PmapSharding. Got sharded dimension 0 with value 1 "
|
||||
r"in shape \(1,\) and the number of devices=2"):
|
||||
s2 = jax.pmap(lambda x: x,
|
||||
devices=list(mesh.devices.flat))(jnp.arange(2)).sharding
|
||||
jax.device_put(x, s2)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user