Merge pull request #14394 from jakevdp:jax-array-methods

PiperOrigin-RevId: 508694486
This commit is contained in:
jax authors 2023-02-10 10:27:14 -08:00
commit 7a864d73bc
2 changed files with 84 additions and 1 deletions

View File

@ -16,9 +16,12 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
from jax._src.sharding import Sharding
from jax._src.array import Shard
from jax._src.array import Device, Shard
from jax._src.typing import ArrayLike
# TODO: alias this to xla_client.Traceback
Traceback = Any
class Array(abc.ABC):
dtype: np.dtype
@ -186,3 +189,22 @@ class Array(abc.ABC):
def at(self) -> Any: ...
@property
def weak_type(self) -> bool: ...
# Methods defined on ArrayImpl, but not on Tracers
def addressable_data(self, index: int) -> Array: ...
def block_until_ready(self) -> Array: ...
def copy_to_host_async(self) -> None: ...
def delete(self) -> None: ...
def device(self) -> Device: ...
def devices(self) -> List[Device]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
def is_deleted(self) -> bool: ...
@property
def is_fully_addressable(self) -> bool: ...
@property
def is_fully_replicated(self) -> bool: ...
def on_device_size_in_bytes(self) -> int: ...
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...

View File

@ -713,6 +713,67 @@ class Tracer(typing.Array):
def _origin_msg(self) -> str:
return ""
# Methods that are only valid for materialized arrays
def addressable_data(self, index):
raise ConcretizationTypeError(self,
f"The addressable_data() method was called on the JAX Tracer object {self}")
@property
def block_until_ready(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'block_until_ready' method is not available on the JAX Tracer object {self}")
@property
def copy_to_host_async(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}")
def delete(self):
raise ConcretizationTypeError(self,
f"The delete() method was called on the JAX Tracer object {self}")
def device(self):
raise ConcretizationTypeError(self,
f"The device() method was called on the JAX Tracer object {self}")
def devices(self):
raise ConcretizationTypeError(self,
f"The devices() method was called on the JAX Tracer object {self}")
@property
def global_shards(self):
raise ConcretizationTypeError(self,
f"The global_shards property was called on the JAX Tracer object {self}")
def is_deleted(self):
raise ConcretizationTypeError(self,
f"The is_deleted() method was called on the JAX Tracer object {self}")
@property
def is_fully_addressable(self):
raise ConcretizationTypeError(self,
f"The is_fully_addressable property was called on the JAX Tracer object {self}")
@property
def is_fully_replicated(self):
raise ConcretizationTypeError(self,
f"The is_fully_replicated property was called on the JAX Tracer object {self}")
def on_device_size_in_bytes(self):
raise ConcretizationTypeError(self,
f"The on_device_size_in_bytes() method was called on the JAX Tracer object {self}")
@property
def traceback(self):
raise ConcretizationTypeError(self,
f"The traceback property was called on the JAX Tracer object {self}")
def unsafe_buffer_pointer(self):
raise ConcretizationTypeError(self,
f"The unsafe_buffer_pointer() method was called on the JAX Tracer object {self}")
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])