mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Relax the memory alignment check between numpy array and jax array on CPU
PiperOrigin-RevId: 567722405
This commit is contained in:
parent
6dcda12140
commit
8276038f63
@ -4464,7 +4464,11 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_asarray_no_copy_np(self):
|
||||
x = np.random.uniform(0, 1, (1000, 2000)).astype("float32")
|
||||
out = jnp.asarray(x)
|
||||
self.assertTrue(np.shares_memory(out, x))
|
||||
|
||||
x_ptr = x.__array_interface__["data"][0]
|
||||
# This is because the PJRT CPU client shares memory if it is 16-byte aligned.
|
||||
if (x_ptr & 15) != 0:
|
||||
self.assertTrue(np.shares_memory(out, x))
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user