mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[typing] define additional methods & properties on jax.Array
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers) In order to simplify the experience for users, we want to maintain only a single jax.Array type, so we define all methods here and raise explicit errors on Tracer instances.
This commit is contained in:
parent
9f0783f35d
commit
60256df668
@ -16,9 +16,12 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.array import Shard
|
||||
from jax._src.array import Device, Shard
|
||||
from jax._src.typing import ArrayLike
|
||||
|
||||
# TODO: alias this to xla_client.Traceback
|
||||
Traceback = Any
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
@ -186,3 +189,22 @@ class Array(abc.ABC):
|
||||
def at(self) -> Any: ...
|
||||
@property
|
||||
def weak_type(self) -> bool: ...
|
||||
|
||||
# Methods defined on ArrayImpl, but not on Tracers
|
||||
def addressable_data(self, index: int) -> Array: ...
|
||||
def block_until_ready(self) -> Array: ...
|
||||
def copy_to_host_async(self) -> None: ...
|
||||
def delete(self) -> None: ...
|
||||
def device(self) -> Device: ...
|
||||
def devices(self) -> List[Device]: ...
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]: ...
|
||||
def is_deleted(self) -> bool: ...
|
||||
@property
|
||||
def is_fully_addressable(self) -> bool: ...
|
||||
@property
|
||||
def is_fully_replicated(self) -> bool: ...
|
||||
def on_device_size_in_bytes(self) -> int: ...
|
||||
@property
|
||||
def traceback(self) -> Traceback: ...
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
@ -713,6 +713,67 @@ class Tracer(typing.Array):
|
||||
def _origin_msg(self) -> str:
|
||||
return ""
|
||||
|
||||
# Methods that are only valid for materialized arrays
|
||||
def addressable_data(self, index):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The addressable_data() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def block_until_ready(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'block_until_ready' method is not available on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def copy_to_host_async(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}")
|
||||
|
||||
def delete(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The delete() method was called on the JAX Tracer object {self}")
|
||||
|
||||
def device(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The device() method was called on the JAX Tracer object {self}")
|
||||
|
||||
def devices(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The devices() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def global_shards(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The global_shards property was called on the JAX Tracer object {self}")
|
||||
|
||||
def is_deleted(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_deleted() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def is_fully_addressable(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_fully_addressable property was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_fully_replicated property was called on the JAX Tracer object {self}")
|
||||
|
||||
def on_device_size_in_bytes(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The on_device_size_in_bytes() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def traceback(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The traceback property was called on the JAX Tracer object {self}")
|
||||
|
||||
def unsafe_buffer_pointer(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The unsafe_buffer_pointer() method was called on the JAX Tracer object {self}")
|
||||
|
||||
# these can be used to set up forwarding of properties and instance methods from
|
||||
# Tracer instances to the underlying avals
|
||||
aval_property = namedtuple("aval_property", ["fget"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user