mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Remove tests for jax.numpy.in1d, which is deprecated.
PiperOrigin-RevId: 561161024
This commit is contained in:
parent
93900245aa
commit
e369445596
@ -612,20 +612,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
element_shape=all_shapes,
|
||||
test_shape=all_shapes,
|
||||
dtype=default_dtypes,
|
||||
invert=[False, True],
|
||||
)
|
||||
def testIn1d(self, element_shape, test_shape, dtype, invert):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
|
||||
jnp_fun = lambda e, t: jnp.in1d(e, t, invert=invert)
|
||||
np_fun = lambda e, t: np.in1d(e, t, invert=invert)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
|
||||
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
|
||||
|
Loading…
x
Reference in New Issue
Block a user