diff --git a/CHANGELOG.md b/CHANGELOG.md index 102692289..bca684356 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,9 @@ Remember to align the itemized text with the first line of an item within a list related functions now raise an error, following a similar change in NumPy. * The config option `jax_cpu_enable_gloo_collectives` is deprecated. Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead. + * The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have + been removed after being deprecated in JAX v0.4.22. Instead use + {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`. * Bug fixes * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. diff --git a/jax/_src/array.py b/jax/_src/array.py index 94238cfb5..fbabc7dd8 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -21,7 +21,6 @@ import operator as op import numpy as np import functools from typing import Any, Callable, cast, TYPE_CHECKING -import warnings from collections.abc import Sequence from jax._src import abstract_arrays @@ -479,30 +478,15 @@ class ArrayImpl(basearray.Array): self._check_if_deleted() return self.sharding.device_set - # TODO(https://github.com/google/jax/issues/12380): Remove this when DA is - # deleted. @property - def device_buffer(self) -> ArrayImpl: - # Added 2023 Dec 6 - warnings.warn( - "arr.device_buffer is deprecated. Use arr.addressable_data(0)", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - if len(self._arrays) == 1: - return self._arrays[0] - raise ValueError('Length of buffers is greater than 1. Please use ' - '`.device_buffers` instead.') + def device_buffer(self): + raise AttributeError( + "arr.device_buffer has been deprecated. Use arr.addressable_data(0)") - # TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is - # deleted. @property - def device_buffers(self) -> Sequence[ArrayImpl]: - # Added 2023 Dec 6 - warnings.warn( - "arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - return cast(Sequence[ArrayImpl], self._arrays) + def device_buffers(self): + raise AttributeError( + "arr.device_buffers has been deprecated. Use [x.data for x in arr.addressable_shards]") def addressable_data(self, index: int) -> ArrayImpl: self._check_if_deleted() diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 6aacbc93b..5eb4e9e5c 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -212,8 +212,6 @@ class Array(abc.ABC): @property def traceback(self) -> Traceback: ... def unsafe_buffer_pointer(self) -> int: ... - @property - def device_buffers(self) -> Any: ... ArrayLike = Union[