Reland: [XLA:Python] Add buffer protocol support to jax.Array

We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
This commit is contained in:
Peter Hawkins 2023-05-25 07:19:56 -07:00 committed by jax authors
parent 6b13d4eb86
commit e464dc8700
3 changed files with 41 additions and 1 deletions

View File

@ -25,6 +25,10 @@ Remember to align the itemized text with the first line of an item within a list
## jaxlib 0.4.11
* Changes
* Readded support for the Python buffer protocol (`memoryview`) on CPU
devices.
## jax 0.4.10 (May 11, 2023)
## jaxlib 0.4.10 (May 11, 2023)

View File

@ -16,6 +16,8 @@
import contextlib
import math
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -688,6 +690,38 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertEqual(out.shape, x.shape)
self.assertArraysEqual(out, x)
@jtu.sample_product(
dtype=[dt for dt in jtu.dtypes.all if dt != jax.dtypes.bfloat16],
shape=[(), (10), (2, 3)],
)
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
def test_buffer_protocol(self, dtype, shape):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Buffer protocol only works on CPU")
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jax.device_put(x)
x_bytes = memoryview(x).tobytes()
y_bytes = memoryview(y).tobytes()
self.assertEqual(x_bytes, y_bytes)
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
def test_buffer_protocol_deletion(self):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Buffer protocol only works on CPU")
rng = jtu.rand_default(self.rng())
x = rng((3, 4), np.float32)
y = jax.device_put(x)
x_bytes = memoryview(x).tobytes()
y_view = memoryview(y)
# The array does not actually get deleted until any external reference is
# dropped. Arguably we should make calling delete() in these circumstances
# return an error instead, but that would be a behavior change for existing
# users.
y.delete()
y_bytes = y_view.tobytes()
self.assertEqual(x_bytes, y_bytes)
def test_array_copy_to_host_async(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
x = pjit(lambda: jnp.arange(8.),
@ -1119,5 +1153,6 @@ class RngShardingTest(jtu.JaxTestCase):
self.assertArraysEqual(y, y_ref1)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -3771,7 +3771,8 @@ class PJitErrorTest(jtu.JaxTestCase):
x = jax.device_put(inp_data)
f = pjit(lambda x: x + 1)
_ = f(x)
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
with self.assertRaisesRegex((RuntimeError, ValueError),
'.*(Array|buffer|Buffer) has been deleted.*'):
x.delete()
_ = f(x)