mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13032 from jakevdp:sharding-attr
PiperOrigin-RevId: 487061046
This commit is contained in:
commit
5e1d7cd52e
@ -13,9 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
# TODO(jakevdp) make this equal jax._src.sharding.Sharding
|
||||
Sharding = Any
|
||||
# TODO(jakevdp) make this equal jax._src.array.Shard
|
||||
Shard = Any
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
@ -26,6 +31,12 @@ class Array(abc.ABC):
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]: ...
|
||||
|
||||
@property
|
||||
def sharding(self) -> Sharding: ...
|
||||
|
||||
@property
|
||||
def addressable_shards(self) -> Sequence[Shard]: ...
|
||||
|
||||
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
|
||||
order=None):
|
||||
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
|
||||
|
13
jax/core.py
13
jax/core.py
@ -559,6 +559,19 @@ class Tracer(typing.Array):
|
||||
def __len__(self):
|
||||
return self.aval._len(self)
|
||||
|
||||
@property
|
||||
def sharding(self):
|
||||
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
|
||||
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
|
||||
# we raise an AttributeError so that hasattr() and getattr() work as expected.
|
||||
raise AttributeError(self,
|
||||
f"The 'sharding' attribute is not available on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def addressable_shards(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The 'addressable_shards' attribute is not available on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def at(self):
|
||||
return self.aval.at.fget(self)
|
||||
|
Loading…
x
Reference in New Issue
Block a user