diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b0e516e5..bb50d4902 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,10 @@ Remember to align the itemized text with the first line of an item within a list ## jaxlib 0.4.11 +* Changes + * Readded support for the Python buffer protocol (`memoryview`) on CPU + devices. + ## jax 0.4.10 (May 11, 2023) ## jaxlib 0.4.10 (May 11, 2023) diff --git a/tests/array_test.py b/tests/array_test.py index 997b4c9ab..9b405bae8 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -16,6 +16,8 @@ import contextlib import math import os +import unittest + from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -688,6 +690,38 @@ class JaxArrayTest(jtu.JaxTestCase): self.assertEqual(out.shape, x.shape) self.assertArraysEqual(out, x) + @jtu.sample_product( + dtype=[dt for dt in jtu.dtypes.all if dt != jax.dtypes.bfloat16], + shape=[(), (10), (2, 3)], + ) + @unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11") + def test_buffer_protocol(self, dtype, shape): + if jtu.device_under_test() != "cpu": + raise unittest.SkipTest("Buffer protocol only works on CPU") + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + y = jax.device_put(x) + x_bytes = memoryview(x).tobytes() + y_bytes = memoryview(y).tobytes() + self.assertEqual(x_bytes, y_bytes) + + @unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11") + def test_buffer_protocol_deletion(self): + if jtu.device_under_test() != "cpu": + raise unittest.SkipTest("Buffer protocol only works on CPU") + rng = jtu.rand_default(self.rng()) + x = rng((3, 4), np.float32) + y = jax.device_put(x) + x_bytes = memoryview(x).tobytes() + y_view = memoryview(y) + # The array does not actually get deleted until any external reference is + # dropped. Arguably we should make calling delete() in these circumstances + # return an error instead, but that would be a behavior change for existing + # users. + y.delete() + y_bytes = y_view.tobytes() + self.assertEqual(x_bytes, y_bytes) + def test_array_copy_to_host_async(self): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) x = pjit(lambda: jnp.arange(8.), @@ -1119,5 +1153,6 @@ class RngShardingTest(jtu.JaxTestCase): self.assertArraysEqual(y, y_ref1) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 834e048f0..6ee2b12db 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3771,7 +3771,8 @@ class PJitErrorTest(jtu.JaxTestCase): x = jax.device_put(inp_data) f = pjit(lambda x: x + 1) _ = f(x) - with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'): + with self.assertRaisesRegex((RuntimeError, ValueError), + '.*(Array|buffer|Buffer) has been deleted.*'): x.delete() _ = f(x)