Relax the memory alignment check between numpy array and jax array on CPU

PiperOrigin-RevId: 567722405
This commit is contained in:
Yash Katariya 2023-09-22 14:48:15 -07:00 committed by jax authors
parent 6dcda12140
commit 8276038f63

View File

@ -4464,7 +4464,11 @@ class APITest(jtu.JaxTestCase):
def test_asarray_no_copy_np(self):
x = np.random.uniform(0, 1, (1000, 2000)).astype("float32")
out = jnp.asarray(x)
self.assertTrue(np.shares_memory(out, x))
x_ptr = x.__array_interface__["data"][0]
# This is because the PJRT CPU client shares memory if it is 16-byte aligned.
if (x_ptr & 15) != 0:
self.assertTrue(np.shares_memory(out, x))
class RematTest(jtu.JaxTestCase):