1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Finalize the deprecation of the arr.device() method

The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
This commit is contained in:
Jake VanderPlas 2024-04-08 19:03:17 -07:00 committed by jax authors
parent af3dcd2b12
commit 1b3aea8205
6 changed files with 6 additions and 31 deletions

@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).
* The `device()` method of JAX arrays has been removed, after being deprecated
since JAX v0.4.21. Use `arr.devices()` instead.
## jaxlib 0.4.27

@ -30,7 +30,6 @@ from jax._src import api_util
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
@ -50,7 +49,6 @@ from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.typing import ArrayLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
deprecations.register(__name__, "device-method")
Shape = tuple[int, ...]
Device = xc.Device
@ -471,21 +469,6 @@ class ArrayImpl(basearray.Array):
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
return per_shard_size * len(self.sharding.device_set)
# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
if deprecations.is_accelerated(__name__, "device-method"):
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
else:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
device_set = self.sharding.device_set
if len(device_set) == 1:
single_device, = device_set
return single_device
raise ValueError('Length of devices is greater than 1. '
'Please use `.devices()`.')
def devices(self) -> set[Device]:
self._check_if_deleted()
return self.sharding.device_set

@ -196,7 +196,6 @@ class Array(abc.ABC):
def block_until_ready(self) -> Array: ...
def copy_to_host_async(self) -> None: ...
def delete(self) -> None: ...
def device(self) -> Device: ...
def devices(self) -> set[Device]: ...
@property
def sharding(self) -> Sharding: ...

@ -849,11 +849,6 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
f"The delete() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def device(self):
raise ConcretizationTypeError(self,
f"The device() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def devices(self):
raise ConcretizationTypeError(self,
f"The devices() method was called on {self._error_repr()}."

@ -19,6 +19,7 @@ from typing import Any, Callable
import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__
from jax.sharding import Sharding
from jax._src.lib import xla_extension as xe
@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None):
return jax.experimental.array_api
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
def _to_device(self, device: xe.Device | Sharding | None, *,
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that
# returns a device, so we must handle this case to satisfy the API spec.
return jax.device_put(self, device() if callable(device) else device)
return jax.device_put(self, device)
def add_array_object_methods():
# TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
setattr(ArrayImpl, "to_device", _to_device)
setattr(ArrayImpl, "device", property(lambda self: self.sharding))

@ -33,7 +33,6 @@ from jax import numpy as jnp
from jax import random
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap
@ -1019,9 +1018,6 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
if not deprecations.is_accelerated('jax._src.array', 'device-method'):
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
self.assertEqual(key.device(), key._base_array.device())
self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes(),
key._base_array.on_device_size_in_bytes())