mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #14365 from jakevdp:reducers-initial
PiperOrigin-RevId: 509253981
This commit is contained in:
commit
c49af18b9b
@ -72,7 +72,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
axis: Axis = None, dtype: DTypeLike = None, out: None = None,
|
||||
keepdims: bool = False, initial: Optional[ArrayLike] = None,
|
||||
where_: Optional[ArrayLike] = None,
|
||||
parallel_reduce: Optional[Callable[..., ArrayLike]] = None,
|
||||
parallel_reduce: Optional[Callable[..., Array]] = None,
|
||||
promote_integers: bool = False) -> Array:
|
||||
bool_op = bool_op or op
|
||||
# Note: we must accept out=None as an argument, because numpy reductions delegate to
|
||||
@ -131,7 +131,12 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
else:
|
||||
result = lax.reduce(a, init_val, op, dims)
|
||||
if initial is not None:
|
||||
result = op(lax.convert_element_type(initial, _asarray(a).dtype), result)
|
||||
# TODO(jakevdp) require initial to be a scalar in order to match the numpy API.
|
||||
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
|
||||
if lax.broadcast_shapes(initial_arr.shape, result.shape) != result.shape:
|
||||
raise ValueError(f"initial value has invalid shape {initial_arr.shape} "
|
||||
f"for reduction with output shape {result.shape}")
|
||||
result = op(initial_arr, result)
|
||||
if keepdims:
|
||||
result = lax.expand_dims(result, pos_dims)
|
||||
return lax.convert_element_type(result, dtype or result_dtype)
|
||||
|
@ -282,6 +282,16 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(rec = JAX_REDUCER_INITIAL_RECORDS)
|
||||
def testReducerBadInitial(self, rec):
|
||||
jnp_op = getattr(jnp, rec.name)
|
||||
arr = jnp.ones((2, 3, 4))
|
||||
initial = jnp.zeros((1, 2, 3))
|
||||
msg = (r"initial value has invalid shape \(1, 2, 3\) "
|
||||
r"for reduction with output shape \(2, 3\)")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jnp_op(arr, axis=-1, initial=initial)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user