mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
PiperOrigin-RevId: 521507810
This commit is contained in:
parent
05249ec770
commit
6f2256ad17
@ -438,6 +438,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
self.shard_shape(global_shape) # raises a good error message
|
||||
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
|
||||
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
|
||||
|
||||
@ -633,6 +634,7 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
self.shard_shape(global_shape) # raises a good error message
|
||||
indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
|
||||
len(self._devices))
|
||||
return dict(safe_zip(self._devices, indices))
|
||||
|
@ -937,6 +937,16 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertTrue(s9.is_equivalent_to(s10, 2))
|
||||
|
||||
def test_devices_indices_map_good_error_message(self):
|
||||
shape = (1, 2)
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Sharding.*implies that array axis 0 is partitioned 2 times, but the "
|
||||
"dimension size is 1"):
|
||||
s.devices_indices_map(shape)
|
||||
|
||||
|
||||
class RngShardingTest(jtu.JaxTestCase):
|
||||
# tests that the PRNGs are automatically sharded as expected
|
||||
|
Loading…
x
Reference in New Issue
Block a user