Merge pull request #6915 from jakevdp:argwhere-size

PiperOrigin-RevId: 378275722
This commit is contained in:
jax authors 2021-06-08 16:39:57 -07:00
commit 30b00095a9
2 changed files with 16 additions and 3 deletions

View File

@ -4467,10 +4467,18 @@ def vander(x, N=None, increasing=False):
### Misc
_ARGWHERE_DOC = """\
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be traced. If ``size`` is specified, the indices of the first ``size``
True elements will be returned; if there are fewer nonzero elements than `size` indicates,
the index arrays will be zero-padded.
"""
@_wraps(np.argwhere)
def argwhere(a):
result = transpose(vstack(nonzero(a)))
@_wraps(np.argwhere, lax_description=_ARGWHERE_DOC)
def argwhere(a, *, size=None):
result = transpose(vstack(nonzero(a, size=size)))
if ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))

View File

@ -970,6 +970,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.skipTest("np.argwhere() result for scalar input changed in numpy 1.18.")
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}".format(
rec.test_name.capitalize(),