mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #6915 from jakevdp:argwhere-size
PiperOrigin-RevId: 378275722
This commit is contained in:
commit
30b00095a9
@ -4467,10 +4467,18 @@ def vander(x, N=None, increasing=False):
|
||||
|
||||
### Misc
|
||||
|
||||
_ARGWHERE_DOC = """\
|
||||
Because the size of the output of ``argwhere`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
|
||||
specifies the size of the leading dimension of the output - it must be specified statically
|
||||
for ``jnp.argwhere`` to be traced. If ``size`` is specified, the indices of the first ``size``
|
||||
True elements will be returned; if there are fewer nonzero elements than `size` indicates,
|
||||
the index arrays will be zero-padded.
|
||||
"""
|
||||
|
||||
@_wraps(np.argwhere)
|
||||
def argwhere(a):
|
||||
result = transpose(vstack(nonzero(a)))
|
||||
@_wraps(np.argwhere, lax_description=_ARGWHERE_DOC)
|
||||
def argwhere(a, *, size=None):
|
||||
result = transpose(vstack(nonzero(a, size=size)))
|
||||
if ndim(a) == 0:
|
||||
return result[:0].reshape(result.shape[0], 0)
|
||||
return result.reshape(result.shape[0], ndim(a))
|
||||
|
@ -970,6 +970,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.skipTest("np.argwhere() result for scalar input changed in numpy 1.18.")
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying a size statically. Full test of this
|
||||
# behavior is in testNonzeroSize().
|
||||
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user