mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #4347 from jakevdp:array-equiv
PiperOrigin-RevId: 332946445
This commit is contained in:
commit
ada6f30f59
@ -68,6 +68,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
around
|
||||
array
|
||||
array_equal
|
||||
array_equiv
|
||||
array_repr
|
||||
array_split
|
||||
array_str
|
||||
|
@ -23,7 +23,7 @@ from .lax_numpy import (
|
||||
alltrue, amax, amin, angle, any, append,
|
||||
apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
|
||||
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
|
||||
array, array_equal, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
|
||||
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
|
||||
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
|
||||
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
|
||||
broadcast_to, can_cast, cbrt, cdouble, ceil, character, clip, column_stack,
|
||||
|
@ -2418,6 +2418,20 @@ def array_equal(a1, a2, equal_nan=False):
|
||||
return all(eq)
|
||||
|
||||
|
||||
@_wraps(np.array_equiv)
|
||||
def array_equiv(a1, a2):
|
||||
try:
|
||||
a1, a2 = asarray(a1), asarray(a2)
|
||||
except Exception:
|
||||
return False
|
||||
try:
|
||||
eq = equal(a1, a2)
|
||||
except ValueError:
|
||||
# shapes are not broadcastable
|
||||
return False
|
||||
return all(eq)
|
||||
|
||||
|
||||
# We can't create uninitialized arrays in XLA; use zeros for empty.
|
||||
empty_like = zeros_like
|
||||
empty = zeros
|
||||
|
@ -142,6 +142,7 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
|
||||
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
|
||||
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
|
||||
@ -593,6 +594,31 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
op(arg, other)
|
||||
|
||||
def testArrayEqualExamples(self):
|
||||
# examples from the array_equal() docstring.
|
||||
self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
|
||||
self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2])))
|
||||
self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
|
||||
self.assertFalse(jnp.array_equal([1, 2], [1, 4]))
|
||||
|
||||
a = np.array([1, np.nan])
|
||||
self.assertFalse(jnp.array_equal(a, a))
|
||||
self.assertTrue(jnp.array_equal(a, a, equal_nan=True))
|
||||
|
||||
a = np.array([1 + 1j])
|
||||
b = a.copy()
|
||||
a.real = np.nan
|
||||
b.imag = np.nan
|
||||
self.assertTrue(jnp.array_equal(a, b, equal_nan=True))
|
||||
|
||||
def testArrayEquivExamples(self):
|
||||
# examples from the array_equiv() docstring.
|
||||
self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
|
||||
self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))
|
||||
self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]]))
|
||||
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]]))
|
||||
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]]))
|
||||
|
||||
def testArrayModule(self):
|
||||
if numpy_dispatch is None:
|
||||
raise SkipTest('requires https://github.com/seberg/numpy-dispatch')
|
||||
|
Loading…
x
Reference in New Issue
Block a user