mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values can be implemented via .
This commit is contained in:
parent
c2b7c5f132
commit
dafb88a649
@ -30,6 +30,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
from Feb 13, 2023.
|
from Feb 13, 2023.
|
||||||
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
||||||
functions.
|
functions.
|
||||||
|
* Breaking Changes
|
||||||
|
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
|
||||||
|
is now required to be a scalar, consistent with the corresponding NumPy API.
|
||||||
|
The previous behavior of broadcating the output against non-scalar `initial`
|
||||||
|
values was an unintentional implementation detail ({jax-issue}`#14446`).
|
||||||
|
|
||||||
## jaxlib 0.4.4
|
## jaxlib 0.4.4
|
||||||
* Breaking changes
|
* Breaking changes
|
||||||
|
@ -131,11 +131,10 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
|||||||
else:
|
else:
|
||||||
result = lax.reduce(a, init_val, op, dims)
|
result = lax.reduce(a, init_val, op, dims)
|
||||||
if initial is not None:
|
if initial is not None:
|
||||||
# TODO(jakevdp) require initial to be a scalar in order to match the numpy API.
|
|
||||||
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
|
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
|
||||||
if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape:
|
if initial_arr.shape != ():
|
||||||
raise ValueError(f"initial value has invalid shape {initial_arr.shape} "
|
raise ValueError("initial value must be a scalar. "
|
||||||
f"for reduction with output shape {result.shape}")
|
f"Got array of shape {initial_arr.shape}")
|
||||||
result = op(initial_arr, result)
|
result = op(initial_arr, result)
|
||||||
if keepdims:
|
if keepdims:
|
||||||
result = lax.expand_dims(result, pos_dims)
|
result = lax.expand_dims(result, pos_dims)
|
||||||
|
@ -287,8 +287,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
jnp_op = getattr(jnp, rec.name)
|
jnp_op = getattr(jnp, rec.name)
|
||||||
arr = jnp.ones((2, 3, 4))
|
arr = jnp.ones((2, 3, 4))
|
||||||
initial = jnp.zeros((1, 2, 3))
|
initial = jnp.zeros((1, 2, 3))
|
||||||
msg = (r"initial value has invalid shape \(1, 2, 3\) "
|
msg = r"initial value must be a scalar. Got array of shape \(1, 2, 3\)"
|
||||||
r"for reduction with output shape \(2, 3\)")
|
|
||||||
with self.assertRaisesRegex(ValueError, msg):
|
with self.assertRaisesRegex(ValueError, msg):
|
||||||
jnp_op(arr, axis=-1, initial=initial)
|
jnp_op(arr, axis=-1, initial=initial)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user