mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values can be implemented via .
This commit is contained in:
parent
c2b7c5f132
commit
dafb88a649
@ -30,6 +30,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
from Feb 13, 2023.
|
||||
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
||||
functions.
|
||||
* Breaking Changes
|
||||
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
|
||||
is now required to be a scalar, consistent with the corresponding NumPy API.
|
||||
The previous behavior of broadcating the output against non-scalar `initial`
|
||||
values was an unintentional implementation detail ({jax-issue}`#14446`).
|
||||
|
||||
## jaxlib 0.4.4
|
||||
* Breaking changes
|
||||
|
@ -131,11 +131,10 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
else:
|
||||
result = lax.reduce(a, init_val, op, dims)
|
||||
if initial is not None:
|
||||
# TODO(jakevdp) require initial to be a scalar in order to match the numpy API.
|
||||
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
|
||||
if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape:
|
||||
raise ValueError(f"initial value has invalid shape {initial_arr.shape} "
|
||||
f"for reduction with output shape {result.shape}")
|
||||
if initial_arr.shape != ():
|
||||
raise ValueError("initial value must be a scalar. "
|
||||
f"Got array of shape {initial_arr.shape}")
|
||||
result = op(initial_arr, result)
|
||||
if keepdims:
|
||||
result = lax.expand_dims(result, pos_dims)
|
||||
|
@ -287,8 +287,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
jnp_op = getattr(jnp, rec.name)
|
||||
arr = jnp.ones((2, 3, 4))
|
||||
initial = jnp.zeros((1, 2, 3))
|
||||
msg = (r"initial value has invalid shape \(1, 2, 3\) "
|
||||
r"for reduction with output shape \(2, 3\)")
|
||||
msg = r"initial value must be a scalar. Got array of shape \(1, 2, 3\)"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jnp_op(arr, axis=-1, initial=initial)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user