mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1485 from hawkinsp/master
Emit a better error if mismatched axis sizes are passed to pmap.
This commit is contained in:
commit
dd72fd5db7
@ -111,7 +111,9 @@ def shard_aval(size, aval):
|
||||
shard_aval_handlers = {}
|
||||
shard_aval_handlers[core.AbstractUnit] = lambda size, x: x
|
||||
def _shard_abstract_array(size, x):
|
||||
assert x.shape[0] == size
|
||||
if x.shape[0] != size:
|
||||
raise ValueError("Axis size {} does not match leading dimension of "
|
||||
"shape {}".format(size, x.shape))
|
||||
return ShapedArray(x.shape[1:], x.dtype)
|
||||
shard_aval_handlers[ShapedArray] = _shard_abstract_array
|
||||
|
||||
|
@ -80,6 +80,13 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testMismatchedAxisSizes(self):
|
||||
n = xla_bridge.device_count()
|
||||
f = pmap(lambda x, y: x + y)
|
||||
jtu.check_raises_regexp(
|
||||
lambda: f(onp.random.randn(n), onp.random.randn(n - 1)), ValueError,
|
||||
"Axis size .* does not match leading dimension of shape .*")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_mesh={}".format(device_mesh_shape),
|
||||
"device_mesh_shape": device_mesh_shape}
|
||||
|
Loading…
x
Reference in New Issue
Block a user