mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate arr.device_buffer and arr.device_buffers
This commit is contained in:
parent
4bdcb11bd6
commit
35b84402c0
@ -8,6 +8,13 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.22
|
||||
|
||||
* Deprecations
|
||||
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
|
||||
Explicit buffers have been replaced by the more flexible array sharding interface,
|
||||
but the previous outputs can be recovered this way:
|
||||
* `arr.device_buffer` becomes `arr.addressable_data(0)`
|
||||
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
|
||||
|
||||
## jaxlib 0.4.22
|
||||
|
||||
## jax 0.4.21 (Dec 4 2023)
|
||||
|
@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
|
||||
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
|
||||
buffers = []
|
||||
for leaf in tree_leaves(output):
|
||||
if hasattr(leaf, "device_buffers"):
|
||||
buffers.extend(leaf.device_buffers)
|
||||
if hasattr(leaf, "addressable_shards"):
|
||||
buffers.extend([shard.data for shard in leaf.addressable_shards])
|
||||
|
||||
try:
|
||||
dispatch.check_special(pjit.pjit_p.name, buffers)
|
||||
|
@ -474,6 +474,10 @@ class ArrayImpl(basearray.Array):
|
||||
# deleted.
|
||||
@property
|
||||
def device_buffer(self) -> ArrayImpl:
|
||||
# Added 2023 Dec 6
|
||||
warnings.warn(
|
||||
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
if len(self._arrays) == 1:
|
||||
return self._arrays[0]
|
||||
@ -484,6 +488,10 @@ class ArrayImpl(basearray.Array):
|
||||
# deleted.
|
||||
@property
|
||||
def device_buffers(self) -> Sequence[ArrayImpl]:
|
||||
# Added 2023 Dec 6
|
||||
warnings.warn(
|
||||
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
return cast(Sequence[ArrayImpl], self._arrays)
|
||||
|
||||
|
@ -1112,17 +1112,11 @@ JIT_IMPLEMENTATION = (
|
||||
)
|
||||
|
||||
class BufferDonationTestCase(JaxTestCase):
|
||||
assertDeleted = lambda self, x: self._assertDeleted(x, True)
|
||||
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)
|
||||
def assertDeleted(self, x):
|
||||
self.assertTrue(x.is_deleted())
|
||||
|
||||
def _assertDeleted(self, x, deleted):
|
||||
if hasattr(x, "_arrays"):
|
||||
self.assertEqual(x.is_deleted(), deleted)
|
||||
elif hasattr(x, "device_buffer"):
|
||||
self.assertEqual(x.device_buffer.is_deleted(), deleted)
|
||||
else:
|
||||
for buffer in x.device_buffers:
|
||||
self.assertEqual(buffer.is_deleted(), deleted)
|
||||
def assertNotDeleted(self, x):
|
||||
self.assertFalse(x.is_deleted())
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -93,7 +93,7 @@
|
||||
"import jax\n",
|
||||
"key = jax.random.PRNGKey(1701)\n",
|
||||
"arr = jax.random.normal(key, (1000,))\n",
|
||||
"device = arr.device_buffer.device()\n",
|
||||
"device = arr.device()\n",
|
||||
"print(f\"JAX device type: {device}\")\n",
|
||||
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
|
||||
],
|
||||
|
@ -929,7 +929,7 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
|
||||
in_axes=['i', None], out_axes=[None],
|
||||
axis_resources={'i': 'x'})
|
||||
h = pjit(f, in_shardings=P('x', None), out_shardings=P(None))(x)
|
||||
assert (h.addressable_data(0) == x.reshape(8)).all()
|
||||
self.assertArraysEqual(h.addressable_data(0), x.reshape(8))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{'testcase_name': name, 'mesh': mesh}
|
||||
@ -949,7 +949,7 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
|
||||
out_shardings=P('x', None),
|
||||
)(x)
|
||||
|
||||
assert (h.addressable_data(0).reshape(4) == x[0, :]*2).all()
|
||||
self.assertArraysEqual(h.addressable_data(0).reshape(4), x[0, :] * 2)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testBareXmapCollective(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user