mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[array API] add device property & to_device method
This commit is contained in:
parent
13e42ad420
commit
613a00044c
@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
will be removed in a future release.
|
will be removed in a future release.
|
||||||
* Updated the repr of gpu devices to be more consistent
|
* Updated the repr of gpu devices to be more consistent
|
||||||
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
|
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
|
||||||
|
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
|
||||||
|
part of JAX's [Array API](https://data-apis.org/array-api) support.
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* Removed a number of previously-deprecated internal APIs related to
|
* Removed a number of previously-deprecated internal APIs related to
|
||||||
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
|
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
|
||||||
|
@ -254,6 +254,13 @@ class ArrayImpl(basearray.Array):
|
|||||||
def sharding(self):
|
def sharding(self):
|
||||||
return self._sharding
|
return self._sharding
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
self._check_if_deleted()
|
||||||
|
if isinstance(self.sharding, SingleDeviceSharding):
|
||||||
|
return list(self.sharding.device_set)[0]
|
||||||
|
return self.sharding
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def weak_type(self):
|
def weak_type(self):
|
||||||
return self.aval.weak_type
|
return self.aval.weak_type
|
||||||
|
@ -22,6 +22,7 @@ from typing import Any, Union
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
# TODO(jakevdp): fix import cycles and define these.
|
# TODO(jakevdp): fix import cycles and define these.
|
||||||
|
Device = Any
|
||||||
Shard = Any
|
Shard = Any
|
||||||
Sharding = Any
|
Sharding = Any
|
||||||
|
|
||||||
@ -112,6 +113,15 @@ class Array(abc.ABC):
|
|||||||
def sharding(self) -> Sharding:
|
def sharding(self) -> Sharding:
|
||||||
"""The sharding for the array."""
|
"""The sharding for the array."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def device(self) -> Device | Sharding:
|
||||||
|
"""Array API-compatible device attribute.
|
||||||
|
|
||||||
|
For single-device arrays, this returns a Device. For sharded arrays, this
|
||||||
|
returns a Sharding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
Array.__module__ = "jax"
|
Array.__module__ = "jax"
|
||||||
|
|
||||||
|
@ -204,6 +204,8 @@ class Array(abc.ABC):
|
|||||||
@property
|
@property
|
||||||
def sharding(self) -> Sharding: ...
|
def sharding(self) -> Sharding: ...
|
||||||
@property
|
@property
|
||||||
|
def device(self) -> Device | Sharding: ...
|
||||||
|
@property
|
||||||
def addressable_shards(self) -> Sequence[Shard]: ...
|
def addressable_shards(self) -> Sequence[Shard]: ...
|
||||||
@property
|
@property
|
||||||
def global_shards(self) -> Sequence[Shard]: ...
|
def global_shards(self) -> Sequence[Shard]: ...
|
||||||
@ -216,6 +218,7 @@ class Array(abc.ABC):
|
|||||||
@property
|
@property
|
||||||
def traceback(self) -> Traceback: ...
|
def traceback(self) -> Traceback: ...
|
||||||
def unsafe_buffer_pointer(self) -> int: ...
|
def unsafe_buffer_pointer(self) -> int: ...
|
||||||
|
def to_device(self, device: Device | Sharding, *, stream: int | Any | None) -> Array: ...
|
||||||
|
|
||||||
|
|
||||||
StaticScalar = Union[
|
StaticScalar = Union[
|
||||||
|
@ -738,6 +738,15 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
|||||||
f"The 'sharding' attribute is not available on {self._error_repr()}."
|
f"The 'sharding' attribute is not available on {self._error_repr()}."
|
||||||
f"{self._origin_msg()}")
|
f"{self._origin_msg()}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
|
||||||
|
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
|
||||||
|
# we raise an AttributeError so that hasattr() and getattr() work as expected.
|
||||||
|
raise AttributeError(self,
|
||||||
|
f"The 'device' attribute is not available on {self._error_repr()}."
|
||||||
|
f"{self._origin_msg()}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def addressable_shards(self):
|
def addressable_shards(self):
|
||||||
raise ConcretizationTypeError(self,
|
raise ConcretizationTypeError(self,
|
||||||
|
@ -83,6 +83,12 @@ class EArray(basearray.Array):
|
|||||||
phys_sharding = self._data.sharding
|
phys_sharding = self._data.sharding
|
||||||
return sharding_impls.logical_sharding(self.aval, phys_sharding)
|
return sharding_impls.logical_sharding(self.aval, phys_sharding)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
|
||||||
|
return self._data.device
|
||||||
|
return self.sharding
|
||||||
|
|
||||||
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
|
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
|
||||||
|
|
||||||
def addressable_data(self, index: int) -> EArray:
|
def addressable_data(self, index: int) -> EArray:
|
||||||
|
@ -32,6 +32,7 @@ import numpy as np
|
|||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.sharding import Sharding
|
from jax.sharding import Sharding
|
||||||
|
from jax._src import api
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src.api_util import _ensure_index_tuple
|
from jax._src.api_util import _ensure_index_tuple
|
||||||
@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
|
|||||||
"""
|
"""
|
||||||
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
|
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
|
||||||
|
|
||||||
|
def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
|
||||||
|
stream: int | Any | None = None):
|
||||||
|
if stream is not None:
|
||||||
|
raise NotImplementedError("stream argument of array.to_device()")
|
||||||
|
return api.device_put(arr, device)
|
||||||
|
|
||||||
|
|
||||||
def _nbytes(arr: ArrayLike) -> int:
|
def _nbytes(arr: ArrayLike) -> int:
|
||||||
"""Total bytes consumed by the elements of the array."""
|
"""Total bytes consumed by the elements of the array."""
|
||||||
@ -694,6 +701,7 @@ _array_methods = {
|
|||||||
"sum": reductions.sum,
|
"sum": reductions.sum,
|
||||||
"swapaxes": lax_numpy.swapaxes,
|
"swapaxes": lax_numpy.swapaxes,
|
||||||
"take": lax_numpy.take,
|
"take": lax_numpy.take,
|
||||||
|
"to_device": _to_device,
|
||||||
"trace": lax_numpy.trace,
|
"trace": lax_numpy.trace,
|
||||||
"transpose": _transpose,
|
"transpose": _transpose,
|
||||||
"var": reductions.var,
|
"var": reductions.var,
|
||||||
|
@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
|
|||||||
return jax.experimental.array_api
|
return jax.experimental.array_api
|
||||||
|
|
||||||
|
|
||||||
def _to_device(self, device: xe.Device | Sharding | None, *,
|
|
||||||
stream: int | Any | None = None):
|
|
||||||
if stream is not None:
|
|
||||||
raise NotImplementedError("stream argument of array.to_device()")
|
|
||||||
return jax.device_put(self, device)
|
|
||||||
|
|
||||||
|
|
||||||
def add_array_object_methods():
|
def add_array_object_methods():
|
||||||
# TODO(jakevdp): set on tracers as well?
|
# TODO(jakevdp): set on tracers as well?
|
||||||
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
|
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
|
||||||
setattr(ArrayImpl, "to_device", _to_device)
|
|
||||||
setattr(ArrayImpl, "device", property(lambda self: self.sharding))
|
|
||||||
|
@ -1276,6 +1276,30 @@ class ShardingTest(jtu.JaxTestCase):
|
|||||||
self.assertEqual(x1, x2)
|
self.assertEqual(x1, x2)
|
||||||
self.assertEqual(hash(x1), hash(x2))
|
self.assertEqual(hash(x1), hash(x2))
|
||||||
|
|
||||||
|
def test_device_attr(self):
|
||||||
|
# For single-device arrays, x.device returns the device
|
||||||
|
x = jnp.ones((2, 10))
|
||||||
|
self.assertEqual(x.device, list(x.devices())[0])
|
||||||
|
|
||||||
|
# For sharded arrays, x.device returns the sharding
|
||||||
|
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||||
|
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||||
|
x = jax.device_put(x, sharding)
|
||||||
|
self.assertEqual(x.device, sharding)
|
||||||
|
|
||||||
|
def test_to_device(self):
|
||||||
|
device = jax.devices()[-1]
|
||||||
|
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||||
|
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||||
|
|
||||||
|
x = jnp.ones((2, 10))
|
||||||
|
|
||||||
|
x_device = x.to_device(device)
|
||||||
|
x_sharding = x.to_device(sharding)
|
||||||
|
|
||||||
|
self.assertEqual(x_device.device, device)
|
||||||
|
self.assertEqual(x_sharding.device, sharding)
|
||||||
|
|
||||||
|
|
||||||
class RngShardingTest(jtu.JaxTestCase):
|
class RngShardingTest(jtu.JaxTestCase):
|
||||||
# tests that the PRNGs are automatically sharded as expected
|
# tests that the PRNGs are automatically sharded as expected
|
||||||
|
Loading…
x
Reference in New Issue
Block a user