Merge pull request #14365 from jakevdp:reducers-initial

PiperOrigin-RevId: 509253981
This commit is contained in:
jax authors 2023-02-13 09:43:46 -08:00
commit c49af18b9b
2 changed files with 17 additions and 2 deletions

View File

@ -72,7 +72,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where_: Optional[ArrayLike] = None,
parallel_reduce: Optional[Callable[..., ArrayLike]] = None,
parallel_reduce: Optional[Callable[..., Array]] = None,
promote_integers: bool = False) -> Array:
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
@ -131,7 +131,12 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
result = op(lax.convert_element_type(initial, _asarray(a).dtype), result)
# TODO(jakevdp) require initial to be a scalar in order to match the numpy API.
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape:
raise ValueError(f"initial value has invalid shape {initial_arr.shape} "
f"for reduction with output shape {result.shape}")
result = op(initial_arr, result)
if keepdims:
result = lax.expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)

View File

@ -282,6 +282,16 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
@jtu.sample_product(rec = JAX_REDUCER_INITIAL_RECORDS)
def testReducerBadInitial(self, rec):
jnp_op = getattr(jnp, rec.name)
arr = jnp.ones((2, 3, 4))
initial = jnp.zeros((1, 2, 3))
msg = (r"initial value has invalid shape \(1, 2, 3\) "
r"for reduction with output shape \(2, 3\)")
with self.assertRaisesRegex(ValueError, msg):
jnp_op(arr, axis=-1, initial=initial)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],