mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix reduction repeated axis error (#3618)
* fix reduction repeated axis error * deflake
This commit is contained in:
parent
e808681f6c
commit
eb2a227588
@ -4201,7 +4201,9 @@ _masking_defreducer(reduce_sum_p,
|
||||
|
||||
|
||||
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
|
||||
del input_shape # unused.
|
||||
del input_shape # Unused.
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
|
||||
return tuple(onp.delete(operand.shape, axes))
|
||||
|
||||
def _reduce_prod_translation_rule(c, operand, *, axes):
|
||||
|
@ -1550,6 +1550,8 @@ def _reduction_dims(a, axis):
|
||||
if axis is None:
|
||||
return tuple(range(ndim(a)))
|
||||
elif isinstance(axis, (np.ndarray, tuple, list)):
|
||||
if len(axis) != len(set(axis)):
|
||||
raise ValueError(f"duplicate value in 'axis': {axis}")
|
||||
return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
|
||||
elif isinstance(axis, int):
|
||||
return (_canonicalize_axis(axis, ndim(a)),)
|
||||
|
@ -3821,6 +3821,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
v = np.arange(12, dtype=np.int32).reshape(3, 4)
|
||||
self.assertEqual(jnp.asarray(v).tolist(), v.tolist())
|
||||
|
||||
def testReductionWithRepeatedAxisError(self):
|
||||
with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"):
|
||||
jnp.sum(jnp.arange(3), (0, 0))
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
@ -1714,6 +1714,11 @@ class LaxTest(jtu.JaxTestCase):
|
||||
TypeError, "Argument .* of type .* is not a valid JAX type"):
|
||||
lax.add(1, 'hi')
|
||||
|
||||
def test_reduction_with_repeated_axes_error(self):
|
||||
with self.assertRaisesRegex(ValueError, "duplicate value in 'axes' .*"):
|
||||
lax.reduce(onp.arange(3), 0, lax.add, (0, 0))
|
||||
|
||||
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
# check casting to ndarray works
|
||||
|
Loading…
x
Reference in New Issue
Block a user