jax.numpy reductions: require initial to be a scalar

This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
This commit is contained in:
Jake VanderPlas 2023-02-14 15:36:18 -08:00
parent c2b7c5f132
commit dafb88a649
3 changed files with 9 additions and 6 deletions

View File

@ -30,6 +30,11 @@ Remember to align the itemized text with the first line of an item within a list
from Feb 13, 2023. from Feb 13, 2023.
* added the {mod}`jax.typing` module, with tools for type annotations of JAX * added the {mod}`jax.typing` module, with tools for type annotations of JAX
functions. functions.
* Breaking Changes
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
is now required to be a scalar, consistent with the corresponding NumPy API.
The previous behavior of broadcating the output against non-scalar `initial`
values was an unintentional implementation detail ({jax-issue}`#14446`).
## jaxlib 0.4.4 ## jaxlib 0.4.4
* Breaking changes * Breaking changes

View File

@ -131,11 +131,10 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
else: else:
result = lax.reduce(a, init_val, op, dims) result = lax.reduce(a, init_val, op, dims)
if initial is not None: if initial is not None:
# TODO(jakevdp) require initial to be a scalar in order to match the numpy API.
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype) initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape: if initial_arr.shape != ():
raise ValueError(f"initial value has invalid shape {initial_arr.shape} " raise ValueError("initial value must be a scalar. "
f"for reduction with output shape {result.shape}") f"Got array of shape {initial_arr.shape}")
result = op(initial_arr, result) result = op(initial_arr, result)
if keepdims: if keepdims:
result = lax.expand_dims(result, pos_dims) result = lax.expand_dims(result, pos_dims)

View File

@ -287,8 +287,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_op = getattr(jnp, rec.name) jnp_op = getattr(jnp, rec.name)
arr = jnp.ones((2, 3, 4)) arr = jnp.ones((2, 3, 4))
initial = jnp.zeros((1, 2, 3)) initial = jnp.zeros((1, 2, 3))
msg = (r"initial value has invalid shape \(1, 2, 3\) " msg = r"initial value must be a scalar. Got array of shape \(1, 2, 3\)"
r"for reduction with output shape \(2, 3\)")
with self.assertRaisesRegex(ValueError, msg): with self.assertRaisesRegex(ValueError, msg):
jnp_op(arr, axis=-1, initial=initial) jnp_op(arr, axis=-1, initial=initial)