fix reduction repeated axis error (#3618)

* fix reduction repeated axis error

* deflake
This commit is contained in:
Matthew Johnson 2020-06-30 21:18:46 -07:00 committed by GitHub
parent e808681f6c
commit eb2a227588
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 1 deletions

View File

@ -4201,7 +4201,9 @@ _masking_defreducer(reduce_sum_p,
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
del input_shape # unused.
del input_shape # Unused.
if len(axes) != len(set(axes)):
raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
return tuple(onp.delete(operand.shape, axes))
def _reduce_prod_translation_rule(c, operand, *, axes):

View File

@ -1550,6 +1550,8 @@ def _reduction_dims(a, axis):
if axis is None:
return tuple(range(ndim(a)))
elif isinstance(axis, (np.ndarray, tuple, list)):
if len(axis) != len(set(axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
elif isinstance(axis, int):
return (_canonicalize_axis(axis, ndim(a)),)

View File

@ -3821,6 +3821,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
v = np.arange(12, dtype=np.int32).reshape(3, 4)
self.assertEqual(jnp.asarray(v).tolist(), v.tolist())
def testReductionWithRepeatedAxisError(self):
with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"):
jnp.sum(jnp.arange(3), (0, 0))
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.

View File

@ -1714,6 +1714,11 @@ class LaxTest(jtu.JaxTestCase):
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')
def test_reduction_with_repeated_axes_error(self):
with self.assertRaisesRegex(ValueError, "duplicate value in 'axes' .*"):
lax.reduce(onp.arange(3), 0, lax.add, (0, 0))
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works