[typing] define additional methods & properties on jax.Array

These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
This commit is contained in:
Jake VanderPlas 2023-02-10 09:42:32 -08:00
parent 9f0783f35d
commit 60256df668
2 changed files with 84 additions and 1 deletions

View File

@ -16,9 +16,12 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
from jax._src.sharding import Sharding
from jax._src.array import Shard
from jax._src.array import Device, Shard
from jax._src.typing import ArrayLike
# TODO: alias this to xla_client.Traceback
Traceback = Any
class Array(abc.ABC):
dtype: np.dtype
@ -186,3 +189,22 @@ class Array(abc.ABC):
def at(self) -> Any: ...
@property
def weak_type(self) -> bool: ...
# Methods defined on ArrayImpl, but not on Tracers
def addressable_data(self, index: int) -> Array: ...
def block_until_ready(self) -> Array: ...
def copy_to_host_async(self) -> None: ...
def delete(self) -> None: ...
def device(self) -> Device: ...
def devices(self) -> List[Device]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
def is_deleted(self) -> bool: ...
@property
def is_fully_addressable(self) -> bool: ...
@property
def is_fully_replicated(self) -> bool: ...
def on_device_size_in_bytes(self) -> int: ...
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...

View File

@ -713,6 +713,67 @@ class Tracer(typing.Array):
def _origin_msg(self) -> str:
return ""
# Methods that are only valid for materialized arrays
def addressable_data(self, index):
raise ConcretizationTypeError(self,
f"The addressable_data() method was called on the JAX Tracer object {self}")
@property
def block_until_ready(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'block_until_ready' method is not available on the JAX Tracer object {self}")
@property
def copy_to_host_async(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}")
def delete(self):
raise ConcretizationTypeError(self,
f"The delete() method was called on the JAX Tracer object {self}")
def device(self):
raise ConcretizationTypeError(self,
f"The device() method was called on the JAX Tracer object {self}")
def devices(self):
raise ConcretizationTypeError(self,
f"The devices() method was called on the JAX Tracer object {self}")
@property
def global_shards(self):
raise ConcretizationTypeError(self,
f"The global_shards property was called on the JAX Tracer object {self}")
def is_deleted(self):
raise ConcretizationTypeError(self,
f"The is_deleted() method was called on the JAX Tracer object {self}")
@property
def is_fully_addressable(self):
raise ConcretizationTypeError(self,
f"The is_fully_addressable property was called on the JAX Tracer object {self}")
@property
def is_fully_replicated(self):
raise ConcretizationTypeError(self,
f"The is_fully_replicated property was called on the JAX Tracer object {self}")
def on_device_size_in_bytes(self):
raise ConcretizationTypeError(self,
f"The on_device_size_in_bytes() method was called on the JAX Tracer object {self}")
@property
def traceback(self):
raise ConcretizationTypeError(self,
f"The traceback property was called on the JAX Tracer object {self}")
def unsafe_buffer_pointer(self):
raise ConcretizationTypeError(self,
f"The unsafe_buffer_pointer() method was called on the JAX Tracer object {self}")
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])