From 8d17cce80e6950c60a19e82434e821bef0166542 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 Apr 2021 09:18:26 -0700 Subject: [PATCH] Add JIT-compatible version of jnp.nonzero --- CHANGELOG.md | 2 ++ jax/_src/numpy/lax_numpy.py | 31 ++++++++++++++++++------------- tests/lax_numpy_test.py | 4 ++++ 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e4c03199..b6ca821b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static. + * {func}`jax.nonzero` has a new optional `size` argument that allows it to + be used within `jit` ({jax-issue}`6501`) * Breaking changes: * Arguments to {func}`jax.jit` other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e38488d9c..24f81586d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2259,22 +2259,27 @@ def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, _NONZERO_DOC = """\ -At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero` -because its output shape is data-dependent. +Because the size of the output of ``nonzero`` is data-dependent, the function is not +typically compatible with JIT. The JAX version adds the optional `size` argument which +specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero`` +to be traced. If specified, the first `size` nonzero elements will be returned; if there +are fewer nonzero elements than `size` indicates, the index arrays will be zero-padded. """ @_wraps(np.nonzero, lax_description=_NONZERO_DOC) -def nonzero(a): - # Note: this function cannot be jitted because its output has a dynamic - # shape. - a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero") - dims = shape(a) - ndims = len(dims) - ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)] - d = concatenate(ds, axis=-1) - indexes = d[a != 0] - return tuple(indexes[..., i] for i in range(ndims)) - +def nonzero(a, *, size=None): + a = atleast_1d(a) + mask = a != 0 + if size is None: + size = mask.sum() + size = core.concrete_or_error(int, size, + "The size argument of jnp.nonzero must be statically specified " + "to use jnp.nonzero within JAX transformations.") + if a.size == 0 or size == 0: + return tuple(zeros(size, int) for dim in a.shape) + flat_indices = cumsum(bincount(cumsum(mask), length=size)) + strides = np.cumprod(a.shape[::-1])[::-1] // a.shape + return tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape)) @_wraps(np.flatnonzero) def flatnonzero(a): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b3466afff..17f05fd41 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -933,6 +933,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + # JIT compilation requires specifying the size statically: + jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2) + self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( jtu.format_shape_dtype_string(shape, dtype)),