mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of arr.device_buffer
and arr.device_buffers
PiperOrigin-RevId: 627899901
This commit is contained in:
parent
66190d10e7
commit
cbe48cad1e
@ -55,6 +55,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
related functions now raise an error, following a similar change in NumPy.
|
||||
* The config option `jax_cpu_enable_gloo_collectives` is deprecated.
|
||||
Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
|
||||
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
|
||||
been removed after being deprecated in JAX v0.4.22. Instead use
|
||||
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
|
||||
|
||||
* Bug fixes
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
|
@ -21,7 +21,6 @@ import operator as op
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
@ -479,30 +478,15 @@ class ArrayImpl(basearray.Array):
|
||||
self._check_if_deleted()
|
||||
return self.sharding.device_set
|
||||
|
||||
# TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
|
||||
# deleted.
|
||||
@property
|
||||
def device_buffer(self) -> ArrayImpl:
|
||||
# Added 2023 Dec 6
|
||||
warnings.warn(
|
||||
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
if len(self._arrays) == 1:
|
||||
return self._arrays[0]
|
||||
raise ValueError('Length of buffers is greater than 1. Please use '
|
||||
'`.device_buffers` instead.')
|
||||
def device_buffer(self):
|
||||
raise AttributeError(
|
||||
"arr.device_buffer has been deprecated. Use arr.addressable_data(0)")
|
||||
|
||||
# TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
|
||||
# deleted.
|
||||
@property
|
||||
def device_buffers(self) -> Sequence[ArrayImpl]:
|
||||
# Added 2023 Dec 6
|
||||
warnings.warn(
|
||||
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
return cast(Sequence[ArrayImpl], self._arrays)
|
||||
def device_buffers(self):
|
||||
raise AttributeError(
|
||||
"arr.device_buffers has been deprecated. Use [x.data for x in arr.addressable_shards]")
|
||||
|
||||
def addressable_data(self, index: int) -> ArrayImpl:
|
||||
self._check_if_deleted()
|
||||
|
@ -212,8 +212,6 @@ class Array(abc.ABC):
|
||||
@property
|
||||
def traceback(self) -> Traceback: ...
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
@property
|
||||
def device_buffers(self) -> Any: ...
|
||||
|
||||
|
||||
ArrayLike = Union[
|
||||
|
Loading…
x
Reference in New Issue
Block a user