mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array. The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do. Fixes https://github.com/google/jax/issues/14713 PiperOrigin-RevId: 535248553
This commit is contained in:
parent
6b13d4eb86
commit
e464dc8700
@ -25,6 +25,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jaxlib 0.4.11
|
||||
|
||||
* Changes
|
||||
* Readded support for the Python buffer protocol (`memoryview`) on CPU
|
||||
devices.
|
||||
|
||||
## jax 0.4.10 (May 11, 2023)
|
||||
|
||||
## jaxlib 0.4.10 (May 11, 2023)
|
||||
|
@ -16,6 +16,8 @@
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -688,6 +690,38 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, x.shape)
|
||||
self.assertArraysEqual(out, x)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=[dt for dt in jtu.dtypes.all if dt != jax.dtypes.bfloat16],
|
||||
shape=[(), (10), (2, 3)],
|
||||
)
|
||||
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
|
||||
def test_buffer_protocol(self, dtype, shape):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("Buffer protocol only works on CPU")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
y = jax.device_put(x)
|
||||
x_bytes = memoryview(x).tobytes()
|
||||
y_bytes = memoryview(y).tobytes()
|
||||
self.assertEqual(x_bytes, y_bytes)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
|
||||
def test_buffer_protocol_deletion(self):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("Buffer protocol only works on CPU")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng((3, 4), np.float32)
|
||||
y = jax.device_put(x)
|
||||
x_bytes = memoryview(x).tobytes()
|
||||
y_view = memoryview(y)
|
||||
# The array does not actually get deleted until any external reference is
|
||||
# dropped. Arguably we should make calling delete() in these circumstances
|
||||
# return an error instead, but that would be a behavior change for existing
|
||||
# users.
|
||||
y.delete()
|
||||
y_bytes = y_view.tobytes()
|
||||
self.assertEqual(x_bytes, y_bytes)
|
||||
|
||||
def test_array_copy_to_host_async(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
x = pjit(lambda: jnp.arange(8.),
|
||||
@ -1119,5 +1153,6 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(y, y_ref1)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -3771,7 +3771,8 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
x = jax.device_put(inp_data)
|
||||
f = pjit(lambda x: x + 1)
|
||||
_ = f(x)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
||||
with self.assertRaisesRegex((RuntimeError, ValueError),
|
||||
'.*(Array|buffer|Buffer) has been deleted.*'):
|
||||
x.delete()
|
||||
_ = f(x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user