mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #17992 from jakevdp:object-test
PiperOrigin-RevId: 571442631
This commit is contained in:
commit
cbf2cdc1bc
@ -931,10 +931,21 @@ class JaxTestCase(parameterized.TestCase):
|
||||
def rng(self):
|
||||
return self._rng
|
||||
|
||||
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg=''):
|
||||
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', allow_object_dtype=False):
|
||||
"""Assert that x and y arrays are exactly equal."""
|
||||
if check_dtypes:
|
||||
self.assertDtypesMatch(x, y)
|
||||
x = np.asarray(x)
|
||||
y = np.asarray(y)
|
||||
|
||||
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
|
||||
# See https://github.com/google/jax/issues/17867
|
||||
raise TypeError(
|
||||
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
|
||||
"If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "
|
||||
"If comparing collections of arrays, consider using assertAllClose. "
|
||||
"To let this test proceed anyway, pass allow_object_dtype=True.")
|
||||
|
||||
# Work around https://github.com/numpy/numpy/issues/18992
|
||||
with np.errstate(over='ignore'):
|
||||
np.testing.assert_array_equal(x, y, err_msg=err_msg)
|
||||
|
@ -423,7 +423,8 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
|
||||
y = jax.pmap(jnp.sin)(x)
|
||||
self.assertArraysEqual([a.device() for a in y],
|
||||
y.sharding._device_assignment)
|
||||
y.sharding._device_assignment,
|
||||
allow_object_dtype=True)
|
||||
|
||||
sin_x = iter(np.sin(x))
|
||||
for i, j in zip(iter(y), sin_x):
|
||||
@ -783,7 +784,8 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
device_assignment = mp_sharding._device_assignment
|
||||
|
||||
self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
|
||||
self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
|
||||
self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
|
||||
allow_object_dtype=True)
|
||||
self.assertTrue(hlo_sharding.is_tiled())
|
||||
self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
|
||||
self.assertListEqual(hlo_sharding.tile_assignment_devices(),
|
||||
|
@ -184,7 +184,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.save(f, arr)
|
||||
f.seek(0)
|
||||
arr_out = jnp.load(f, allow_pickle=allow_pickle)
|
||||
self.assertArraysEqual(arr, arr_out)
|
||||
self.assertArraysEqual(arr, arr_out, allow_object_dtype=True)
|
||||
|
||||
def testArrayEqualExamples(self):
|
||||
# examples from the array_equal() docstring.
|
||||
|
Loading…
x
Reference in New Issue
Block a user