diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 4520129f0..48db32b3b 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -68,6 +68,7 @@ Not every function in NumPy is implemented; contributions are welcome! around array array_equal + array_equiv array_repr array_split array_str diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 197d9eb50..566d13e36 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -23,7 +23,7 @@ from .lax_numpy import ( alltrue, amax, amin, angle, any, append, apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around, - array, array_equal, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d, + array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d, atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays, broadcast_to, can_cast, cbrt, cdouble, ceil, character, clip, column_stack, diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 8cca5e300..9d416a565 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2418,6 +2418,20 @@ def array_equal(a1, a2, equal_nan=False): return all(eq) +@_wraps(np.array_equiv) +def array_equiv(a1, a2): + try: + a1, a2 = asarray(a1), asarray(a2) + except Exception: + return False + try: + eq = equal(a1, a2) + except ValueError: + # shapes are not broadcastable + return False + return all(eq) + + # We can't create uninitialized arrays in XLA; use zeros for empty. empty_like = zeros_like empty = zeros diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 4e012b9fc..09470cc45 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -142,6 +142,7 @@ JAX_ONE_TO_ONE_OP_RECORDS = [ all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes, @@ -593,6 +594,31 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): with self.assertRaises(TypeError): op(arg, other) + def testArrayEqualExamples(self): + # examples from the array_equal() docstring. + self.assertTrue(jnp.array_equal([1, 2], [1, 2])) + self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) + self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) + self.assertFalse(jnp.array_equal([1, 2], [1, 4])) + + a = np.array([1, np.nan]) + self.assertFalse(jnp.array_equal(a, a)) + self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) + + a = np.array([1 + 1j]) + b = a.copy() + a.real = np.nan + b.imag = np.nan + self.assertTrue(jnp.array_equal(a, b, equal_nan=True)) + + def testArrayEquivExamples(self): + # examples from the array_equiv() docstring. + self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) + self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) + self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) + def testArrayModule(self): if numpy_dispatch is None: raise SkipTest('requires https://github.com/seberg/numpy-dispatch')