Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error

PiperOrigin-RevId: 521507810
This commit is contained in:
Yash Katariya 2023-04-03 11:04:43 -07:00 committed by jax authors
parent 05249ec770
commit 6f2256ad17
2 changed files with 12 additions and 0 deletions

View File

@ -438,6 +438,7 @@ class PmapSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
self.shard_shape(global_shape) # raises a good error message
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@ -633,6 +634,7 @@ class GSPMDSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
self.shard_shape(global_shape) # raises a good error message
indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
len(self._devices))
return dict(safe_zip(self._devices, indices))

View File

@ -937,6 +937,16 @@ class ShardingTest(jtu.JaxTestCase):
self.assertTrue(s9.is_equivalent_to(s10, 2))
def test_devices_indices_map_good_error_message(self):
shape = (1, 2)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
"Sharding.*implies that array axis 0 is partitioned 2 times, but the "
"dimension size is 1"):
s.devices_indices_map(shape)
class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected