Fix test failure in array_interoperability_test due to 64-bit dtype squashing.

PiperOrigin-RevId: 542026628
This commit is contained in:
Peter Hawkins 2023-06-20 13:11:23 -07:00 committed by jax authors
parent 8f83da6a4c
commit b842e868e0

View File

@ -181,9 +181,10 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)
z = np.asarray(y)
a = y.__cuda_array_interface__
self.assertEqual(shape, a["shape"])
self.assertEqual(x.__array_interface__["typestr"], a["typestr"])
self.assertEqual(z.__array_interface__["typestr"], a["typestr"])
def testCudaArrayInterfaceBfloat16Fails(self):
rng = jtu.rand_default(self.rng())