Finalize deprecation of arr.device_buffer and arr.device_buffers

PiperOrigin-RevId: 627899901
This commit is contained in:
Jake VanderPlas 2024-04-24 17:26:38 -07:00 committed by jax authors
parent 66190d10e7
commit cbe48cad1e
3 changed files with 9 additions and 24 deletions

View File

@ -55,6 +55,9 @@ Remember to align the itemized text with the first line of an item within a list
related functions now raise an error, following a similar change in NumPy.
* The config option `jax_cpu_enable_gloo_collectives` is deprecated.
Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
been removed after being deprecated in JAX v0.4.22. Instead use
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.

View File

@ -21,7 +21,6 @@ import operator as op
import numpy as np
import functools
from typing import Any, Callable, cast, TYPE_CHECKING
import warnings
from collections.abc import Sequence
from jax._src import abstract_arrays
@ -479,30 +478,15 @@ class ArrayImpl(basearray.Array):
self._check_if_deleted()
return self.sharding.device_set
# TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
# deleted.
@property
def device_buffer(self) -> ArrayImpl:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
raise ValueError('Length of buffers is greater than 1. Please use '
'`.device_buffers` instead.')
def device_buffer(self):
raise AttributeError(
"arr.device_buffer has been deprecated. Use arr.addressable_data(0)")
# TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
# deleted.
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
return cast(Sequence[ArrayImpl], self._arrays)
def device_buffers(self):
raise AttributeError(
"arr.device_buffers has been deprecated. Use [x.data for x in arr.addressable_shards]")
def addressable_data(self, index: int) -> ArrayImpl:
self._check_if_deleted()

View File

@ -212,8 +212,6 @@ class Array(abc.ABC):
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...
@property
def device_buffers(self) -> Any: ...
ArrayLike = Union[