diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 84b865aec..92b7b304a 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -72,7 +72,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: axis: Axis = None, dtype: DTypeLike = None, out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None, where_: Optional[ArrayLike] = None, - parallel_reduce: Optional[Callable[..., ArrayLike]] = None, + parallel_reduce: Optional[Callable[..., Array]] = None, promote_integers: bool = False) -> Array: bool_op = bool_op or op # Note: we must accept out=None as an argument, because numpy reductions delegate to @@ -131,7 +131,12 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: else: result = lax.reduce(a, init_val, op, dims) if initial is not None: - result = op(lax.convert_element_type(initial, _asarray(a).dtype), result) + # TODO(jakevdp) require initial to be a scalar in order to match the numpy API. + initial_arr = lax.convert_element_type(initial, _asarray(a).dtype) + if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape: + raise ValueError(f"initial value has invalid shape {initial_arr.shape} " + f"for reduction with output shape {result.shape}") + result = op(initial_arr, result) if keepdims: result = lax.expand_dims(result, pos_dims) return lax.convert_element_type(result, dtype or result_dtype) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index e89324ce1..c5f559d2a 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -282,6 +282,16 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) + @jtu.sample_product(rec = JAX_REDUCER_INITIAL_RECORDS) + def testReducerBadInitial(self, rec): + jnp_op = getattr(jnp, rec.name) + arr = jnp.ones((2, 3, 4)) + initial = jnp.zeros((1, 2, 3)) + msg = (r"initial value has invalid shape \(1, 2, 3\) " + r"for reduction with output shape \(2, 3\)") + with self.assertRaisesRegex(ValueError, msg): + jnp_op(arr, axis=-1, initial=initial) + @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],