testOgrid: make test compatible with NumPy 2.0

This commit is contained in:
Jake VanderPlas 2024-01-29 09:21:23 -08:00
parent 0d152dcfab
commit bd5e9bef33

View File

@ -4803,9 +4803,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testOgrid(self):
# wrap indexer for appropriate dtype defaults.
np_ogrid = _indexer_with_default_outputs(np.ogrid)
def assertListOfArraysEqual(xs, ys):
self.assertIsInstance(xs, list)
self.assertIsInstance(ys, list)
def assertSequenceOfArraysEqual(xs, ys):
self.assertIsInstance(xs, (list, tuple))
self.assertIsInstance(ys, (list, tuple))
self.assertEqual(len(xs), len(ys))
for x, y in zip(xs, ys):
self.assertArraysEqual(x, y)
@ -4814,10 +4814,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2])
# List of arrays
assertListOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,])
assertListOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
assertListOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,])
assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
# Corner cases
self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:])
# Complex number steps