Merge pull request #4347 from jakevdp:array-equiv

PiperOrigin-RevId: 332946445
This commit is contained in:
jax authors 2020-09-21 15:20:19 -07:00
commit ada6f30f59
4 changed files with 42 additions and 1 deletions

View File

@ -68,6 +68,7 @@ Not every function in NumPy is implemented; contributions are welcome!
around
array
array_equal
array_equiv
array_repr
array_split
array_str

View File

@ -23,7 +23,7 @@ from .lax_numpy import (
alltrue, amax, amin, angle, any, append,
apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
array, array_equal, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
broadcast_to, can_cast, cbrt, cdouble, ceil, character, clip, column_stack,

View File

@ -2418,6 +2418,20 @@ def array_equal(a1, a2, equal_nan=False):
return all(eq)
@_wraps(np.array_equiv)
def array_equiv(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
try:
eq = equal(a1, a2)
except ValueError:
# shapes are not broadcastable
return False
return all(eq)
# We can't create uninitialized arrays in XLA; use zeros for empty.
empty_like = zeros_like
empty = zeros

View File

@ -142,6 +142,7 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
@ -593,6 +594,31 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaises(TypeError):
op(arg, other)
def testArrayEqualExamples(self):
# examples from the array_equal() docstring.
self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2])))
self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
self.assertFalse(jnp.array_equal([1, 2], [1, 4]))
a = np.array([1, np.nan])
self.assertFalse(jnp.array_equal(a, a))
self.assertTrue(jnp.array_equal(a, a, equal_nan=True))
a = np.array([1 + 1j])
b = a.copy()
a.real = np.nan
b.imag = np.nan
self.assertTrue(jnp.array_equal(a, b, equal_nan=True))
def testArrayEquivExamples(self):
# examples from the array_equiv() docstring.
self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))
self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]]))
def testArrayModule(self):
if numpy_dispatch is None:
raise SkipTest('requires https://github.com/seberg/numpy-dispatch')