Deprecate arr.device_buffer and arr.device_buffers

This commit is contained in:
Jake VanderPlas 2023-12-06 10:20:29 -08:00
parent 4bdcb11bd6
commit 35b84402c0
6 changed files with 24 additions and 15 deletions

View File

@ -8,6 +8,13 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.22
* Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
## jaxlib 0.4.22
## jax 0.4.21 (Dec 4 2023)

View File

@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
for leaf in tree_leaves(output):
if hasattr(leaf, "device_buffers"):
buffers.extend(leaf.device_buffers)
if hasattr(leaf, "addressable_shards"):
buffers.extend([shard.data for shard in leaf.addressable_shards])
try:
dispatch.check_special(pjit.pjit_p.name, buffers)

View File

@ -474,6 +474,10 @@ class ArrayImpl(basearray.Array):
# deleted.
@property
def device_buffer(self) -> ArrayImpl:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
@ -484,6 +488,10 @@ class ArrayImpl(basearray.Array):
# deleted.
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
return cast(Sequence[ArrayImpl], self._arrays)

View File

@ -1112,17 +1112,11 @@ JIT_IMPLEMENTATION = (
)
class BufferDonationTestCase(JaxTestCase):
assertDeleted = lambda self, x: self._assertDeleted(x, True)
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)
def assertDeleted(self, x):
self.assertTrue(x.is_deleted())
def _assertDeleted(self, x, deleted):
if hasattr(x, "_arrays"):
self.assertEqual(x.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)
def assertNotDeleted(self, x):
self.assertFalse(x.is_deleted())
@contextmanager

View File

@ -93,7 +93,7 @@
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"device = arr.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
],

View File

@ -929,7 +929,7 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
in_axes=['i', None], out_axes=[None],
axis_resources={'i': 'x'})
h = pjit(f, in_shardings=P('x', None), out_shardings=P(None))(x)
assert (h.addressable_data(0) == x.reshape(8)).all()
self.assertArraysEqual(h.addressable_data(0), x.reshape(8))
@parameterized.named_parameters(
{'testcase_name': name, 'mesh': mesh}
@ -949,7 +949,7 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
out_shardings=P('x', None),
)(x)
assert (h.addressable_data(0).reshape(4) == x[0, :]*2).all()
self.assertArraysEqual(h.addressable_data(0).reshape(4), x[0, :] * 2)
@jtu.with_mesh([('x', 2)])
def testBareXmapCollective(self):