Merge pull request #1485 from hawkinsp/master

Emit a better error if mismatched axis sizes are passed to pmap.
This commit is contained in:
Peter Hawkins 2019-10-10 13:49:50 -04:00 committed by GitHub
commit dd72fd5db7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 1 deletions

View File

@ -111,7 +111,9 @@ def shard_aval(size, aval):
shard_aval_handlers = {}
shard_aval_handlers[core.AbstractUnit] = lambda size, x: x
def _shard_abstract_array(size, x):
assert x.shape[0] == size
if x.shape[0] != size:
raise ValueError("Axis size {} does not match leading dimension of "
"shape {}".format(size, x.shape))
return ShapedArray(x.shape[1:], x.dtype)
shard_aval_handlers[ShapedArray] = _shard_abstract_array

View File

@ -80,6 +80,13 @@ class PmapTest(jtu.JaxTestCase):
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMismatchedAxisSizes(self):
n = xla_bridge.device_count()
f = pmap(lambda x, y: x + y)
jtu.check_raises_regexp(
lambda: f(onp.random.randn(n), onp.random.randn(n - 1)), ValueError,
"Axis size .* does not match leading dimension of shape .*")
@parameterized.named_parameters(
{"testcase_name": "_mesh={}".format(device_mesh_shape),
"device_mesh_shape": device_mesh_shape}