Remove tests for jax.numpy.in1d, which is deprecated.

PiperOrigin-RevId: 561161024
This commit is contained in:
Peter Hawkins 2023-08-29 15:51:57 -07:00 committed by jax authors
parent 93900245aa
commit e369445596

View File

@ -612,20 +612,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
element_shape=all_shapes,
test_shape=all_shapes,
dtype=default_dtypes,
invert=[False, True],
)
def testIn1d(self, element_shape, test_shape, dtype, invert):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
jnp_fun = lambda e, t: jnp.in1d(e, t, invert=invert)
np_fun = lambda e, t: np.in1d(e, t, invert=invert)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],