From dafb88a6495a6587f720a2ce5daa9ff27c2acd54 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 Feb 2023 15:36:18 -0800 Subject: [PATCH] jax.numpy reductions: require initial to be a scalar This follows the requirements of numpy's reduction API. Non-scalar initial values can be implemented via . --- CHANGELOG.md | 5 +++++ jax/_src/numpy/reductions.py | 7 +++---- tests/lax_numpy_reducers_test.py | 3 +-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7617b74b6..ac82cff3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,11 @@ Remember to align the itemized text with the first line of an item within a list from Feb 13, 2023. * added the {mod}`jax.typing` module, with tools for type annotations of JAX functions. +* Breaking Changes + * the `initial` argument to reduction functions like :func:`jax.numpy.sum` + is now required to be a scalar, consistent with the corresponding NumPy API. + The previous behavior of broadcating the output against non-scalar `initial` + values was an unintentional implementation detail ({jax-issue}`#14446`). ## jaxlib 0.4.4 * Breaking changes diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 92b7b304a..32d5677f9 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -131,11 +131,10 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: else: result = lax.reduce(a, init_val, op, dims) if initial is not None: - # TODO(jakevdp) require initial to be a scalar in order to match the numpy API. initial_arr = lax.convert_element_type(initial, _asarray(a).dtype) - if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape: - raise ValueError(f"initial value has invalid shape {initial_arr.shape} " - f"for reduction with output shape {result.shape}") + if initial_arr.shape != (): + raise ValueError("initial value must be a scalar. " + f"Got array of shape {initial_arr.shape}") result = op(initial_arr, result) if keepdims: result = lax.expand_dims(result, pos_dims) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index c5f559d2a..f68137771 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -287,8 +287,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): jnp_op = getattr(jnp, rec.name) arr = jnp.ones((2, 3, 4)) initial = jnp.zeros((1, 2, 3)) - msg = (r"initial value has invalid shape \(1, 2, 3\) " - r"for reduction with output shape \(2, 3\)") + msg = r"initial value must be a scalar. Got array of shape \(1, 2, 3\)" with self.assertRaisesRegex(ValueError, msg): jnp_op(arr, axis=-1, initial=initial)