Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
PiperOrigin-RevId: 511905019
commit 5a8c12db9f
parent b5026207bc
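For context, the new behavior can be sketched as follows. This is a hypothetical repro, based on the test added at the end of this diff, assuming a host with 2 devices (jtu.create_global_mesh comes from JAX's test utilities):

    # Hypothetical repro of the new error, assuming 2 available devices.
    import jax
    import jax.numpy as jnp
    from jax._src import test_util as jtu
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jtu.create_global_mesh((2,), ('x',))  # 1-D mesh over 2 devices
    s = NamedSharding(mesh, P('x'))              # shard dimension 0 across 'x'

    # A length-1 array cannot be evenly split across 2 devices, so
    # device_put now raises a ValueError explaining that dimension 0 must
    # be divisible by 2, instead of failing later with an opaque error.
    jax.device_put(jnp.ones((1,)), s)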
@@ -63,7 +63,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
                                GSPMDSharding, NamedSharding, PartitionSpec,
-                               Sharding)
+                               Sharding, XLACompatibleSharding)
 from jax._src.util import flatten, unflatten
 
 
@@ -1387,6 +1387,18 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array
       committed=(device is not None))
 
 
+# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
+# to check if shardings are compatible with the input.
+def _check_sharding(x, s):
+  from jax._src import pjit
+
+  if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
+    pjit.pjit_check_aval_sharding(
+        (s,), (x,), "device_put args", allow_uneven_sharding=False)
+
+  s.shard_shape(x.shape)  # should raise an Error if incompatible
+
+
 def _device_put_impl(
     x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
   from jax.interpreters import pxla
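The new _check_sharding helper validates through two paths: non-pmap XLACompatibleShardings are checked with pjit's pjit_check_aval_sharding, and shard_shape is called for every sharding so that any sharding whose shard_shape rejects the global shape raises here. A rough illustration of the shard_shape path, reusing the 2-device mesh and sharding from the sketch above:

    # shard_shape maps a global shape to a per-device shard shape and
    # should raise if the two are incompatible.
    s.shard_shape((4,))  # OK: each of the 2 devices holds a (2,) shard
    s.shard_shape((1,))  # raises ValueError: dim 0 not divisible by 2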
@@ -1409,6 +1421,9 @@ def _device_put_impl(
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which "
           f"represents addressable devices, but got {s}")
+
+    _check_sharding(x, s)
+
     if getattr(x, 'sharding', None) == s:
       return x
     # TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
@@ -518,12 +518,20 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
     sharded_dim = None
+    sharded_dim_size = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
         sharded_dim = i
+        sharded_dim_size = s.size
         break
     if sharded_dim is None:
       return global_shape
+    if global_shape[sharded_dim] != sharded_dim_size:
+      raise ValueError(
+          f'The sharded dimension must be equal to the number of '
+          f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
+          f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
+          f'the number of devices={len(self._device_assignment)}')
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
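Unlike a NamedSharding, a PmapSharding produced by jax.pmap expects the unstacked dimension to equal the device count exactly, which the check added above enforces. A minimal sketch, again assuming the same 2 devices:

    # PmapSharding.shard_shape now validates the sharded dimension.
    s2 = jax.pmap(lambda x: x)(jnp.arange(2)).sharding  # a PmapSharding
    s2.shard_shape((2,))  # OK: dimension 0 matches the number of devices
    s2.shard_shape((1,))  # raises the new ValueError added above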
@@ -3407,6 +3407,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
         out2.sharding._to_xla_op_sharding(out2.ndim))
     self.assertListEqual(ns2, [2, 2, 1, 1])
 
+  def test_device_put_sharding_nondivisible_sharding_error(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P('x'))
+
+    x = jnp.ones((1,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+                    'divisible by 2, but it is equal to 1 '):
+      jax.device_put(x, s)
+
+    y = jnp.ones((2,))
+    with self.assertRaisesRegex(
+        ValueError, 'implies that the size of its dimension 0 should be '
+                    'divisible by 2, but it is equal to 1 '):
+      jax.device_put((y, x), s)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "The sharded dimension must be equal to the number of "
+        "devices passed to PmapSharding. Got sharded dimension 0 with value 1 "
+        r"in shape \(1,\) and the number of devices=2"):
+      s2 = jax.pmap(lambda x: x,
+                    devices=list(mesh.devices.flat))(jnp.arange(2)).sharding
+      jax.device_put(x, s2)
+
 class TempSharding(Sharding):
 
 