Add JIT-compatible version of jnp.nonzero

This commit is contained in:
Jake VanderPlas 2021-04-20 09:18:26 -07:00
parent c09037bd14
commit 8d17cce80e
3 changed files with 24 additions and 13 deletions

View File

@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`)
* Breaking changes:
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments

View File

@ -2259,22 +2259,27 @@ def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
_NONZERO_DOC = """\
At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero`
because its output shape is data-dependent.
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero``
to be traced. If specified, the first `size` nonzero elements will be returned; if there
are fewer nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""
@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
def nonzero(a):
# Note: this function cannot be jitted because its output has a dynamic
# shape.
a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
dims = shape(a)
ndims = len(dims)
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
d = concatenate(ds, axis=-1)
indexes = d[a != 0]
return tuple(indexes[..., i] for i in range(ndims))
def nonzero(a, *, size=None):
a = atleast_1d(a)
mask = a != 0
if size is None:
size = mask.sum()
size = core.concrete_or_error(int, size,
"The size argument of jnp.nonzero must be statically specified "
"to use jnp.nonzero within JAX transformations.")
if a.size == 0 or size == 0:
return tuple(zeros(size, int) for dim in a.shape)
flat_indices = cumsum(bincount(cumsum(mask), length=size))
strides = np.cumprod(a.shape[::-1])[::-1] // a.shape
return tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))
@_wraps(np.flatnonzero)
def flatnonzero(a):

View File

@ -933,6 +933,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
# JIT compilation requires specifying the size statically:
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),