mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix test failure in array_interoperability_test due to 64-bit dtype squashing.
PiperOrigin-RevId: 542026628
This commit is contained in:
parent
8f83da6a4c
commit
b842e868e0
@ -181,9 +181,10 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
y = jnp.array(x)
|
||||
z = np.asarray(y)
|
||||
a = y.__cuda_array_interface__
|
||||
self.assertEqual(shape, a["shape"])
|
||||
self.assertEqual(x.__array_interface__["typestr"], a["typestr"])
|
||||
self.assertEqual(z.__array_interface__["typestr"], a["typestr"])
|
||||
|
||||
def testCudaArrayInterfaceBfloat16Fails(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user