mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
testOgrid: make test compatible with NumPy 2.0
This commit is contained in:
parent
0d152dcfab
commit
bd5e9bef33
@ -4803,9 +4803,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testOgrid(self):
|
||||
# wrap indexer for appropriate dtype defaults.
|
||||
np_ogrid = _indexer_with_default_outputs(np.ogrid)
|
||||
def assertListOfArraysEqual(xs, ys):
|
||||
self.assertIsInstance(xs, list)
|
||||
self.assertIsInstance(ys, list)
|
||||
def assertSequenceOfArraysEqual(xs, ys):
|
||||
self.assertIsInstance(xs, (list, tuple))
|
||||
self.assertIsInstance(ys, (list, tuple))
|
||||
self.assertEqual(len(xs), len(ys))
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertArraysEqual(x, y)
|
||||
@ -4814,10 +4814,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
|
||||
self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2])
|
||||
# List of arrays
|
||||
assertListOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,])
|
||||
assertListOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
|
||||
assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
|
||||
assertListOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
|
||||
assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,])
|
||||
assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
|
||||
assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
|
||||
assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
|
||||
# Corner cases
|
||||
self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:])
|
||||
# Complex number steps
|
||||
|
Loading…
x
Reference in New Issue
Block a user