Merge pull request #17992 from jakevdp:object-test

PiperOrigin-RevId: 571442631
This commit is contained in:
jax authors 2023-10-06 14:59:39 -07:00
commit cbf2cdc1bc
3 changed files with 17 additions and 4 deletions

View File

@ -931,10 +931,21 @@ class JaxTestCase(parameterized.TestCase):
def rng(self):
return self._rng
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg=''):
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', allow_object_dtype=False):
"""Assert that x and y arrays are exactly equal."""
if check_dtypes:
self.assertDtypesMatch(x, y)
x = np.asarray(x)
y = np.asarray(y)
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
# See https://github.com/google/jax/issues/17867
raise TypeError(
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
"If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "
"If comparing collections of arrays, consider using assertAllClose. "
"To let this test proceed anyway, pass allow_object_dtype=True.")
# Work around https://github.com/numpy/numpy/issues/18992
with np.errstate(over='ignore'):
np.testing.assert_array_equal(x, y, err_msg=err_msg)

View File

@ -423,7 +423,8 @@ class JaxArrayTest(jtu.JaxTestCase):
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y],
y.sharding._device_assignment)
y.sharding._device_assignment,
allow_object_dtype=True)
sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):
@ -783,7 +784,8 @@ class ShardingTest(jtu.JaxTestCase):
device_assignment = mp_sharding._device_assignment
self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
allow_object_dtype=True)
self.assertTrue(hlo_sharding.is_tiled())
self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
self.assertListEqual(hlo_sharding.tile_assignment_devices(),

View File

@ -184,7 +184,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.save(f, arr)
f.seek(0)
arr_out = jnp.load(f, allow_pickle=allow_pickle)
self.assertArraysEqual(arr, arr_out)
self.assertArraysEqual(arr, arr_out, allow_object_dtype=True)
def testArrayEqualExamples(self):
# examples from the array_equal() docstring.