Add more complete test for jnp.nonzero size argument

This commit is contained in:
Jake VanderPlas 2021-06-08 16:40:53 -07:00
parent 30b00095a9
commit 0f4f4102ce

View File

@ -932,8 +932,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
# JIT compilation requires specifying the size statically:
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_size={}".format(
jtu.format_shape_dtype_string(shape, dtype), size),
"shape": shape, "dtype": dtype, "size": size}
for shape in nonempty_array_shapes
for dtype in all_dtypes
for size in [1, 5, 10]))
def testNonzeroSize(self, shape, dtype, size):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.nonzero(x)
if size <= len(result[0]):
return tuple(arg[:size] for arg in result)
else:
return tuple(np.concatenate([arg, np.zeros(size - len(arg), arg.dtype)])
for arg in result)
jnp_fun = lambda x: jnp.nonzero(x, size=size)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(