mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #22597 from jakevdp:arr-device
PiperOrigin-RevId: 655238275
This commit is contained in:
commit
dc42ba0e41
@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
will be removed in a future release.
|
||||
* Updated the repr of gpu devices to be more consistent
|
||||
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
|
||||
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
|
||||
part of JAX's [Array API](https://data-apis.org/array-api) support.
|
||||
* Deprecations
|
||||
* Removed a number of previously-deprecated internal APIs related to
|
||||
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
|
||||
|
@ -254,6 +254,13 @@ class ArrayImpl(basearray.Array):
|
||||
def sharding(self):
|
||||
return self._sharding
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
self._check_if_deleted()
|
||||
if isinstance(self.sharding, SingleDeviceSharding):
|
||||
return list(self.sharding.device_set)[0]
|
||||
return self.sharding
|
||||
|
||||
@property
|
||||
def weak_type(self):
|
||||
return self.aval.weak_type
|
||||
|
@ -22,6 +22,7 @@ from typing import Any, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
# TODO(jakevdp): fix import cycles and define these.
|
||||
Device = Any
|
||||
Shard = Any
|
||||
Sharding = Any
|
||||
|
||||
@ -112,6 +113,15 @@ class Array(abc.ABC):
|
||||
def sharding(self) -> Sharding:
|
||||
"""The sharding for the array."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def device(self) -> Device | Sharding:
|
||||
"""Array API-compatible device attribute.
|
||||
|
||||
For single-device arrays, this returns a Device. For sharded arrays, this
|
||||
returns a Sharding.
|
||||
"""
|
||||
|
||||
|
||||
Array.__module__ = "jax"
|
||||
|
||||
|
@ -204,6 +204,8 @@ class Array(abc.ABC):
|
||||
@property
|
||||
def sharding(self) -> Sharding: ...
|
||||
@property
|
||||
def device(self) -> Device | Sharding: ...
|
||||
@property
|
||||
def addressable_shards(self) -> Sequence[Shard]: ...
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]: ...
|
||||
@ -216,6 +218,7 @@ class Array(abc.ABC):
|
||||
@property
|
||||
def traceback(self) -> Traceback: ...
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
def to_device(self, device: Device | Sharding, *, stream: int | Any | None) -> Array: ...
|
||||
|
||||
|
||||
StaticScalar = Union[
|
||||
|
@ -738,6 +738,15 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
f"The 'sharding' attribute is not available on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
|
||||
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
|
||||
# we raise an AttributeError so that hasattr() and getattr() work as expected.
|
||||
raise AttributeError(self,
|
||||
f"The 'device' attribute is not available on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
@property
|
||||
def addressable_shards(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
|
@ -83,6 +83,12 @@ class EArray(basearray.Array):
|
||||
phys_sharding = self._data.sharding
|
||||
return sharding_impls.logical_sharding(self.aval, phys_sharding)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
|
||||
return self._data.device
|
||||
return self.sharding
|
||||
|
||||
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
|
||||
|
||||
def addressable_data(self, index: int) -> EArray:
|
||||
|
@ -32,6 +32,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.sharding import Sharding
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
|
||||
"""
|
||||
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
|
||||
|
||||
def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
|
||||
stream: int | Any | None = None):
|
||||
if stream is not None:
|
||||
raise NotImplementedError("stream argument of array.to_device()")
|
||||
return api.device_put(arr, device)
|
||||
|
||||
|
||||
def _nbytes(arr: ArrayLike) -> int:
|
||||
"""Total bytes consumed by the elements of the array."""
|
||||
@ -694,6 +701,7 @@ _array_methods = {
|
||||
"sum": reductions.sum,
|
||||
"swapaxes": lax_numpy.swapaxes,
|
||||
"take": lax_numpy.take,
|
||||
"to_device": _to_device,
|
||||
"trace": lax_numpy.trace,
|
||||
"transpose": _transpose,
|
||||
"var": reductions.var,
|
||||
|
@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
|
||||
return jax.experimental.array_api
|
||||
|
||||
|
||||
def _to_device(self, device: xe.Device | Sharding | None, *,
|
||||
stream: int | Any | None = None):
|
||||
if stream is not None:
|
||||
raise NotImplementedError("stream argument of array.to_device()")
|
||||
return jax.device_put(self, device)
|
||||
|
||||
|
||||
def add_array_object_methods():
|
||||
# TODO(jakevdp): set on tracers as well?
|
||||
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
|
||||
setattr(ArrayImpl, "to_device", _to_device)
|
||||
setattr(ArrayImpl, "device", property(lambda self: self.sharding))
|
||||
|
@ -1276,6 +1276,30 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(x1, x2)
|
||||
self.assertEqual(hash(x1), hash(x2))
|
||||
|
||||
def test_device_attr(self):
|
||||
# For single-device arrays, x.device returns the device
|
||||
x = jnp.ones((2, 10))
|
||||
self.assertEqual(x.device, list(x.devices())[0])
|
||||
|
||||
# For sharded arrays, x.device returns the sharding
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
x = jax.device_put(x, sharding)
|
||||
self.assertEqual(x.device, sharding)
|
||||
|
||||
def test_to_device(self):
|
||||
device = jax.devices()[-1]
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
x = jnp.ones((2, 10))
|
||||
|
||||
x_device = x.to_device(device)
|
||||
x_sharding = x.to_device(sharding)
|
||||
|
||||
self.assertEqual(x_device.device, device)
|
||||
self.assertEqual(x_sharding.device, sharding)
|
||||
|
||||
|
||||
class RngShardingTest(jtu.JaxTestCase):
|
||||
# tests that the PRNGs are automatically sharded as expected
|
||||
|
Loading…
x
Reference in New Issue
Block a user