[array API] add device property & to_device method

Jake VanderPlas 2024-07-23 09:48:51 -07:00
parent 13e42ad420
commit 613a00044c
9 changed files with 69 additions and 9 deletions

CHANGELOG.md

@@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
     will be removed in a future release.
   * Updated the repr of gpu devices to be more consistent
     with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
+  * Added the `device` property and `to_device` method to {class}`jax.Array`, as
+    part of JAX's [Array API](https://data-apis.org/array-api) support.
 * Deprecations
   * Removed a number of previously-deprecated internal APIs related to
     polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
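
As a quick illustration of the changelog entry above, a minimal usage sketch of the two additions (device reprs vary by backend; the comments assume a single-device CPU runtime):

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
print(x.device)                      # e.g. CpuDevice(id=0): the new property

y = x.to_device(jax.devices()[0])    # the new method; copies to the given device
print(y.device)                      # e.g. CpuDevice(id=0)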

jax/_src/array.py

@@ -254,6 +254,13 @@ class ArrayImpl(basearray.Array):
   def sharding(self):
     return self._sharding
 
+  @property
+  def device(self):
+    self._check_if_deleted()
+    if isinstance(self.sharding, SingleDeviceSharding):
+      return list(self.sharding.device_set)[0]
+    return self.sharding
+
   @property
   def weak_type(self):
     return self.aval.weak_type
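
The branch in `ArrayImpl.device` above means the property's type depends on placement. A sketch of both cases (the sharded branch assumes at least two local devices):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x = jnp.ones((2, 10))
assert x.device == list(x.devices())[0]   # SingleDeviceSharding -> Device

mesh = Mesh(jax.devices()[:2], ('x',))    # assumes >= 2 devices
xs = jax.device_put(x, NamedSharding(mesh, P('x')))
assert xs.device == xs.sharding           # anything else -> Sharding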

jax/_src/basearray.py

@@ -22,6 +22,7 @@ from typing import Any, Union
 from collections.abc import Sequence
 
 # TODO(jakevdp): fix import cycles and define these.
+Device = Any
 Shard = Any
 Sharding = Any
 
@@ -112,6 +113,15 @@ class Array(abc.ABC):
   def sharding(self) -> Sharding:
     """The sharding for the array."""
 
+  @property
+  @abc.abstractmethod
+  def device(self) -> Device | Sharding:
+    """Array API-compatible device attribute.
+
+    For single-device arrays, this returns a Device. For sharded arrays, this
+    returns a Sharding.
+    """
+
 
 Array.__module__ = "jax"
Array.__module__ = "jax"

jax/_src/basearray.pyi

@@ -204,6 +204,8 @@ class Array(abc.ABC):
   @property
   def sharding(self) -> Sharding: ...
   @property
+  def device(self) -> Device | Sharding: ...
+  @property
   def addressable_shards(self) -> Sequence[Shard]: ...
   @property
   def global_shards(self) -> Sequence[Shard]: ...
@@ -216,6 +218,7 @@ class Array(abc.ABC):
   @property
   def traceback(self) -> Traceback: ...
   def unsafe_buffer_pointer(self) -> int: ...
+  def to_device(self, device: Device | Sharding, *, stream: int | Any | None = None) -> Array: ...
 
 
 StaticScalar = Union[
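
With the stub in place, a type checker accepts either argument form for `to_device`; a small sketch of both call forms (assuming at least one local device):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x = jnp.ones(8)
y = x.to_device(jax.devices()[0])            # Device argument
mesh = Mesh(jax.devices()[:1], ('x',))
z = x.to_device(NamedSharding(mesh, P()))    # Sharding argument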

jax/_src/core.py

@@ -738,6 +738,15 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
         f"The 'sharding' attribute is not available on {self._error_repr()}."
         f"{self._origin_msg()}")
 
+  @property
+  def device(self):
+    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
+    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
+    # we raise an AttributeError so that hasattr() and getattr() work as expected.
+    raise AttributeError(self,
+      f"The 'device' attribute is not available on {self._error_repr()}."
+      f"{self._origin_msg()}")
+
   @property
   def addressable_shards(self):
     raise ConcretizationTypeError(self,
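
The comment in the new `device` property is the crux: by raising AttributeError rather than ConcretizationTypeError, attribute-probing code keeps working when it encounters a Tracer. A small sketch of the behavior this preserves:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Inside jit, x is a Tracer; the AttributeError makes hasattr()
  # return False instead of aborting the trace.
  assert not hasattr(x, 'device')
  return x + 1

f(jnp.ones(3))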

jax/_src/earray.py

@@ -83,6 +83,12 @@ class EArray(basearray.Array):
     phys_sharding = self._data.sharding
     return sharding_impls.logical_sharding(self.aval, phys_sharding)
 
+  @property
+  def device(self):
+    if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
+      return self._data.device
+    return self.sharding
+
   # TODO(mattjj): not implemented below here, need more methods from ArrayImpl
 
   def addressable_data(self, index: int) -> EArray:
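
`EArray.device` mirrors the `ArrayImpl` logic above. Restated as a standalone rule (the function `device_of` is ours, for illustration only):

import jax
from jax.sharding import Sharding, SingleDeviceSharding

def device_of(sharding: Sharding):
  # Single-device placements report the Device; anything else
  # reports the Sharding itself.
  if isinstance(sharding, SingleDeviceSharding):
    return list(sharding.device_set)[0]
  return sharding

print(device_of(SingleDeviceSharding(jax.devices()[0])))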

jax/_src/numpy/array_methods.py

@@ -32,6 +32,7 @@ import numpy as np
 import jax
 from jax import lax
 from jax.sharding import Sharding
+from jax._src import api
 from jax._src import core
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
@@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
   """
   return lax_numpy.astype(arr, dtype, copy=copy, device=device)
 
+
+def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
+               stream: int | Any | None = None):
+  if stream is not None:
+    raise NotImplementedError("stream argument of array.to_device()")
+  return api.device_put(arr, device)
 
 def _nbytes(arr: ArrayLike) -> int:
   """Total bytes consumed by the elements of the array."""
@@ -694,6 +701,7 @@ _array_methods = {
   "sum": reductions.sum,
   "swapaxes": lax_numpy.swapaxes,
   "take": lax_numpy.take,
+  "to_device": _to_device,
   "trace": lax_numpy.trace,
   "transpose": _transpose,
   "var": reductions.var,

jax/experimental/array_api/_array_methods.py

@@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
   return jax.experimental.array_api
 
-def _to_device(self, device: xe.Device | Sharding | None, *,
-               stream: int | Any | None = None):
-  if stream is not None:
-    raise NotImplementedError("stream argument of array.to_device()")
-  return jax.device_put(self, device)
-
-
 def add_array_object_methods():
   # TODO(jakevdp): set on tracers as well?
   setattr(ArrayImpl, "__array_namespace__", _array_namespace)
-  setattr(ArrayImpl, "to_device", _to_device)
-  setattr(ArrayImpl, "device", property(lambda self: self.sharding))

tests/array_test.py

@@ -1276,6 +1276,30 @@ class ShardingTest(jtu.JaxTestCase):
     self.assertEqual(x1, x2)
     self.assertEqual(hash(x1), hash(x2))
 
+  def test_device_attr(self):
+    # For single-device arrays, x.device returns the device
+    x = jnp.ones((2, 10))
+    self.assertEqual(x.device, list(x.devices())[0])
+
+    # For sharded arrays, x.device returns the sharding
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    sharding = jax.sharding.NamedSharding(mesh, P('x'))
+    x = jax.device_put(x, sharding)
+    self.assertEqual(x.device, sharding)
+
+  def test_to_device(self):
+    device = jax.devices()[-1]
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    sharding = jax.sharding.NamedSharding(mesh, P('x'))
+
+    x = jnp.ones((2, 10))
+
+    x_device = x.to_device(device)
+    x_sharding = x.to_device(sharding)
+
+    self.assertEqual(x_device.device, device)
+    self.assertEqual(x_sharding.device, sharding)
+
 
 class RngShardingTest(jtu.JaxTestCase):
   # tests that the PRNGs are automatically sharded as expected