mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6925 from jakevdp:nonzero-test
PiperOrigin-RevId: 378420194
This commit is contained in:
commit
8362db6ef8
@ -932,8 +932,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying the size statically:
|
||||
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), size),
|
||||
"shape": shape, "dtype": dtype, "size": size}
|
||||
for shape in nonempty_array_shapes
|
||||
for dtype in all_dtypes
|
||||
for size in [1, 5, 10]))
|
||||
def testNonzeroSize(self, shape, dtype, size):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
|
||||
def np_fun(x):
|
||||
result = np.nonzero(x)
|
||||
if size <= len(result[0]):
|
||||
return tuple(arg[:size] for arg in result)
|
||||
else:
|
||||
return tuple(np.concatenate([arg, np.zeros(size - len(arg), arg.dtype)])
|
||||
for arg in result)
|
||||
jnp_fun = lambda x: jnp.nonzero(x, size=size)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user