Merge pull request #6925 from jakevdp:nonzero-test

PiperOrigin-RevId: 378420194
This commit is contained in:
jax authors 2021-06-09 09:10:59 -07:00
commit 8362db6ef8

View File

@ -932,8 +932,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
# JIT compilation requires specifying the size statically:
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_size={}".format(
jtu.format_shape_dtype_string(shape, dtype), size),
"shape": shape, "dtype": dtype, "size": size}
for shape in nonempty_array_shapes
for dtype in all_dtypes
for size in [1, 5, 10]))
def testNonzeroSize(self, shape, dtype, size):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.nonzero(x)
if size <= len(result[0]):
return tuple(arg[:size] for arg in result)
else:
return tuple(np.concatenate([arg, np.zeros(size - len(arg), arg.dtype)])
for arg in result)
jnp_fun = lambda x: jnp.nonzero(x, size=size)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(