Merge pull request #13032 from jakevdp:sharding-attr

PiperOrigin-RevId: 487061046
This commit is contained in:
jax authors 2022-11-08 15:01:23 -08:00
commit 5e1d7cd52e
2 changed files with 25 additions and 1 deletions

View File

@ -13,9 +13,14 @@
# limitations under the License.
import abc
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
# TODO(jakevdp) make this equal jax._src.sharding.Sharding
Sharding = Any
# TODO(jakevdp) make this equal jax._src.array.Shard
Shard = Any
class Array(abc.ABC):
dtype: np.dtype
@ -26,6 +31,12 @@ class Array(abc.ABC):
@property
def shape(self) -> Tuple[int, ...]: ...
@property
def sharding(self) -> Sharding: ...
@property
def addressable_shards(self) -> Sequence[Shard]: ...
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."

View File

@ -559,6 +559,19 @@ class Tracer(typing.Array):
def __len__(self):
return self.aval._len(self)
@property
def sharding(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
# we raise an AttributeError so that hasattr() and getattr() work as expected.
raise AttributeError(self,
f"The 'sharding' attribute is not available on the JAX Tracer object {self}")
@property
def addressable_shards(self):
raise ConcretizationTypeError(self,
f"The 'addressable_shards' attribute is not available on the JAX Tracer object {self}")
@property
def at(self):
return self.aval.at.fget(self)