mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
This commit is contained in:
parent
0fbc453d94
commit
07f5d7a475
@ -43,7 +43,6 @@ from jax._src.numpy import array_api_metadata
|
||||
from jax._src.numpy import indexing
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.numpy import tensor_contractions
|
||||
from jax._src.numpy import util
|
||||
from jax._src.pjit import PartitionSpec
|
||||
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
|
||||
from jax._src.numpy import reductions
|
||||
@ -160,17 +159,6 @@ def _conjugate(self: Array) -> Array:
|
||||
"""
|
||||
return ufuncs.conjugate(self)
|
||||
|
||||
def _contains(self: Array, other: ArrayLike) -> Array:
|
||||
"""Implements __contains__ for JAX arrays."""
|
||||
if self.ndim != 1:
|
||||
raise ValueError("Array.__contains__: search array must be one-dimensional,"
|
||||
f" got arr.shape={self.shape}.")
|
||||
query = util.ensure_arraylike('Array.__contains__', other)
|
||||
if query.ndim != 0:
|
||||
raise ValueError("Array.__contains__: query value must be a scalar,"
|
||||
f" got {query.shape=}")
|
||||
return reductions.any(self == query)
|
||||
|
||||
def _copy(self: Array) -> Array:
|
||||
"""Return a copy of the array.
|
||||
|
||||
@ -944,7 +932,6 @@ _array_operators = {
|
||||
"getitem": _getitem,
|
||||
"setitem": _unimplemented_setitem,
|
||||
"copy": _copy,
|
||||
"contains": _contains,
|
||||
"deepcopy": _deepcopy,
|
||||
"neg": lambda self: ufuncs.negative(self),
|
||||
"pos": lambda self: ufuncs.positive(self),
|
||||
|
@ -3757,18 +3757,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
y = jax.vmap(f)(x)
|
||||
self.assertIsNot(x, y)
|
||||
|
||||
def testArrayContains(self):
|
||||
self.assertTrue(1 in jnp.arange(4))
|
||||
self.assertFalse(100 in jnp.arange(4))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Array.__contains__: search array must be one-dimensional"):
|
||||
1 in jnp.array(1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Array.__contains__: query value must be a scalar"):
|
||||
jnp.arange(2) in jnp.arange(2)
|
||||
|
||||
def testArrayUnsupportedDtypeError(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'JAX only supports number, bool, and string dtypes.*'
|
||||
|
Loading…
x
Reference in New Issue
Block a user