mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context. PiperOrigin-RevId: 623015500
This commit is contained in:
parent
af3dcd2b12
commit
1b3aea8205
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
|
||||
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
|
||||
`max` ({jax-issue}`20550`).
|
||||
* The `device()` method of JAX arrays has been removed, after being deprecated
|
||||
since JAX v0.4.21. Use `arr.devices()` instead.
|
||||
|
||||
|
||||
## jaxlib 0.4.27
|
||||
|
@ -30,7 +30,6 @@ from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import errors
|
||||
@ -50,7 +49,6 @@ from jax._src.layout import DeviceLocalLayout, Layout
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
|
||||
|
||||
deprecations.register(__name__, "device-method")
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
@ -471,21 +469,6 @@ class ArrayImpl(basearray.Array):
|
||||
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
|
||||
return per_shard_size * len(self.sharding.device_set)
|
||||
|
||||
# TODO(yashkatariya): Remove this method when everyone is using devices().
|
||||
def device(self) -> Device:
|
||||
if deprecations.is_accelerated(__name__, "device-method"):
|
||||
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
|
||||
else:
|
||||
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
device_set = self.sharding.device_set
|
||||
if len(device_set) == 1:
|
||||
single_device, = device_set
|
||||
return single_device
|
||||
raise ValueError('Length of devices is greater than 1. '
|
||||
'Please use `.devices()`.')
|
||||
|
||||
def devices(self) -> set[Device]:
|
||||
self._check_if_deleted()
|
||||
return self.sharding.device_set
|
||||
|
@ -196,7 +196,6 @@ class Array(abc.ABC):
|
||||
def block_until_ready(self) -> Array: ...
|
||||
def copy_to_host_async(self) -> None: ...
|
||||
def delete(self) -> None: ...
|
||||
def device(self) -> Device: ...
|
||||
def devices(self) -> set[Device]: ...
|
||||
@property
|
||||
def sharding(self) -> Sharding: ...
|
||||
|
@ -849,11 +849,6 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
f"The delete() method was called on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
def device(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The device() method was called on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
def devices(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The devices() method was called on {self._error_repr()}."
|
||||
|
@ -19,6 +19,7 @@ from typing import Any, Callable
|
||||
import jax
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax.experimental.array_api._version import __array_api_version__
|
||||
from jax.sharding import Sharding
|
||||
|
||||
from jax._src.lib import xla_extension as xe
|
||||
|
||||
@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None):
|
||||
return jax.experimental.array_api
|
||||
|
||||
|
||||
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
|
||||
def _to_device(self, device: xe.Device | Sharding | None, *,
|
||||
stream: int | Any | None = None):
|
||||
if stream is not None:
|
||||
raise NotImplementedError("stream argument of array.to_device()")
|
||||
# The type of device is defined by Array.device. In JAX, this is a callable that
|
||||
# returns a device, so we must handle this case to satisfy the API spec.
|
||||
return jax.device_put(self, device() if callable(device) else device)
|
||||
return jax.device_put(self, device)
|
||||
|
||||
|
||||
def add_array_object_methods():
|
||||
# TODO(jakevdp): set on tracers as well?
|
||||
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
|
||||
setattr(ArrayImpl, "to_device", _to_device)
|
||||
setattr(ArrayImpl, "device", property(lambda self: self.sharding))
|
||||
|
@ -33,7 +33,6 @@ from jax import numpy as jnp
|
||||
from jax import random
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
@ -1019,9 +1018,6 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
|
||||
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
|
||||
if not deprecations.is_accelerated('jax._src.array', 'device-method'):
|
||||
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
|
||||
self.assertEqual(key.device(), key._base_array.device())
|
||||
self.assertEqual(key.devices(), key._base_array.devices())
|
||||
self.assertEqual(key.on_device_size_in_bytes(),
|
||||
key._base_array.on_device_size_in_bytes())
|
||||
|
Loading…
x
Reference in New Issue
Block a user