mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add JIT-compatible version of jnp.nonzero
This commit is contained in:
parent
c09037bd14
commit
8d17cce80e
@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
|
||||
keyword arguments. A new `static_argnames` option has been added to specify
|
||||
keyword arguments as static.
|
||||
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
|
||||
be used within `jit` ({jax-issue}`6501`)
|
||||
* Breaking changes:
|
||||
* Arguments to {func}`jax.jit` other than the function are now marked as
|
||||
keyword-only. This change is to prevent accidental breakage when arguments
|
||||
|
@ -2259,22 +2259,27 @@ def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
|
||||
|
||||
_NONZERO_DOC = """\
|
||||
At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero`
|
||||
because its output shape is data-dependent.
|
||||
Because the size of the output of ``nonzero`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional `size` argument which
|
||||
specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero``
|
||||
to be traced. If specified, the first `size` nonzero elements will be returned; if there
|
||||
are fewer nonzero elements than `size` indicates, the index arrays will be zero-padded.
|
||||
"""
|
||||
|
||||
@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
|
||||
def nonzero(a):
|
||||
# Note: this function cannot be jitted because its output has a dynamic
|
||||
# shape.
|
||||
a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
|
||||
dims = shape(a)
|
||||
ndims = len(dims)
|
||||
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
|
||||
d = concatenate(ds, axis=-1)
|
||||
indexes = d[a != 0]
|
||||
return tuple(indexes[..., i] for i in range(ndims))
|
||||
|
||||
def nonzero(a, *, size=None):
|
||||
a = atleast_1d(a)
|
||||
mask = a != 0
|
||||
if size is None:
|
||||
size = mask.sum()
|
||||
size = core.concrete_or_error(int, size,
|
||||
"The size argument of jnp.nonzero must be statically specified "
|
||||
"to use jnp.nonzero within JAX transformations.")
|
||||
if a.size == 0 or size == 0:
|
||||
return tuple(zeros(size, int) for dim in a.shape)
|
||||
flat_indices = cumsum(bincount(cumsum(mask), length=size))
|
||||
strides = np.cumprod(a.shape[::-1])[::-1] // a.shape
|
||||
return tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))
|
||||
|
||||
@_wraps(np.flatnonzero)
|
||||
def flatnonzero(a):
|
||||
|
@ -933,6 +933,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying the size statically:
|
||||
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user