mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.ufunc: support where argument in ufunc.reduce
This commit is contained in:
parent
d452eea9b6
commit
ac1233b453
@ -83,26 +83,33 @@ class ufunc:
|
||||
@_wraps(np.ufunc.reduce, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
|
||||
def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None):
|
||||
check_arraylike(f"{self.__name__}.reduce", a)
|
||||
if self.nin != 2:
|
||||
raise ValueError("reduce only supported for binary ufuncs")
|
||||
if self.nout != 1:
|
||||
raise ValueError("reduce only supported for functions returning a single value")
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"out argument of {self.__name__}.reduce()")
|
||||
# TODO(jakevdp): implement where.
|
||||
if initial is not None:
|
||||
check_arraylike(f"{self.__name__}.reduce", initial)
|
||||
if where is not None:
|
||||
raise NotImplementedError(f"where argument of {self.__name__}.reduce()")
|
||||
return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial)
|
||||
check_arraylike(f"{self.__name__}.reduce", where)
|
||||
if self.identity is None and initial is None:
|
||||
raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, "
|
||||
"so to use a where mask one has to specify 'initial'.")
|
||||
if lax_internal._dtype(where) != bool:
|
||||
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
|
||||
return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
|
||||
|
||||
def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None):
|
||||
def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
|
||||
assert self.nin == 2 and self.nout == 1
|
||||
check_arraylike(f"{self.__name__}.reduce", arr)
|
||||
arr = lax_internal.asarray(arr)
|
||||
if initial is None:
|
||||
initial = self.identity
|
||||
if dtype is None:
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
|
||||
|
||||
if where is not None:
|
||||
where = _broadcast_to(where, arr.shape)
|
||||
if isinstance(axis, tuple):
|
||||
axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis)
|
||||
raise NotImplementedError("tuple of axes")
|
||||
@ -112,6 +119,8 @@ class ufunc:
|
||||
else:
|
||||
final_shape = ()
|
||||
arr = arr.ravel()
|
||||
if where is not None:
|
||||
where = where.ravel()
|
||||
axis = 0
|
||||
else:
|
||||
axis = canonicalize_axis(axis, arr.ndim)
|
||||
@ -123,23 +132,28 @@ class ufunc:
|
||||
# TODO: handle without transpose?
|
||||
if axis != 0:
|
||||
arr = _moveaxis(arr, axis, 0)
|
||||
if where is not None:
|
||||
where = _moveaxis(where, axis, 0)
|
||||
|
||||
if initial is None and arr.shape[0] == 0:
|
||||
raise ValueError("zero-size array to reduction operation {self.__name__} which has no ideneity")
|
||||
|
||||
def body_fun(i, val):
|
||||
return self._call(val, arr[i].astype(dtype))
|
||||
if where is None:
|
||||
return self._call(val, arr[i].astype(dtype))
|
||||
else:
|
||||
return _where(where[i], self._call(val, arr[i].astype(dtype)), val)
|
||||
|
||||
if initial is None:
|
||||
start = 1
|
||||
initial = arr[0]
|
||||
start_index = 1
|
||||
start_value = arr[0]
|
||||
else:
|
||||
check_arraylike(f"{self.__name__}.reduce", arr)
|
||||
start = 0
|
||||
start_index = 0
|
||||
start_value = initial
|
||||
start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:])
|
||||
|
||||
initial = _broadcast_to(lax_internal.asarray(initial).astype(dtype), arr.shape[1:])
|
||||
result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value)
|
||||
|
||||
result = jax.lax.fori_loop(start, arr.shape[0], body_fun, initial)
|
||||
if keepdims:
|
||||
result = result.reshape(final_shape)
|
||||
return result
|
||||
|
@ -145,7 +145,40 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses?
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
|
||||
# Need initial if identity is None
|
||||
initial = 1 if identity is None else None
|
||||
|
||||
def jnp_fun(arr, where):
|
||||
return jnp.frompyfunc(func, nin, nout, identity=identity).reduce(
|
||||
arr, where=where, axis=axis, initial=initial)
|
||||
|
||||
@cast_outputs
|
||||
def np_fun(arr, where):
|
||||
# Workaround for https://github.com/numpy/numpy/issues/24530
|
||||
# TODO(jakevdp): remove this when possible.
|
||||
initial_workaround = identity if initial is None else initial
|
||||
return np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce(
|
||||
arr, where=where, axis=axis, initial=initial_workaround)
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_where = jtu.rand_bool(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
|
Loading…
x
Reference in New Issue
Block a user