mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #14394 from jakevdp:jax-array-methods
PiperOrigin-RevId: 508694486
This commit is contained in:
commit
7a864d73bc
@ -16,9 +16,12 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.array import Shard
|
||||
from jax._src.array import Device, Shard
|
||||
from jax._src.typing import ArrayLike
|
||||
|
||||
# TODO: alias this to xla_client.Traceback
|
||||
Traceback = Any
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
@ -186,3 +189,22 @@ class Array(abc.ABC):
|
||||
def at(self) -> Any: ...
|
||||
@property
|
||||
def weak_type(self) -> bool: ...
|
||||
|
||||
# Methods defined on ArrayImpl, but not on Tracers
|
||||
def addressable_data(self, index: int) -> Array: ...
|
||||
def block_until_ready(self) -> Array: ...
|
||||
def copy_to_host_async(self) -> None: ...
|
||||
def delete(self) -> None: ...
|
||||
def device(self) -> Device: ...
|
||||
def devices(self) -> List[Device]: ...
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]: ...
|
||||
def is_deleted(self) -> bool: ...
|
||||
@property
|
||||
def is_fully_addressable(self) -> bool: ...
|
||||
@property
|
||||
def is_fully_replicated(self) -> bool: ...
|
||||
def on_device_size_in_bytes(self) -> int: ...
|
||||
@property
|
||||
def traceback(self) -> Traceback: ...
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
@ -713,6 +713,67 @@ class Tracer(typing.Array):
|
||||
def _origin_msg(self) -> str:
|
||||
return ""
|
||||
|
||||
# Methods that are only valid for materialized arrays
|
||||
def addressable_data(self, index):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The addressable_data() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def block_until_ready(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'block_until_ready' method is not available on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def copy_to_host_async(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}")
|
||||
|
||||
def delete(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The delete() method was called on the JAX Tracer object {self}")
|
||||
|
||||
def device(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The device() method was called on the JAX Tracer object {self}")
|
||||
|
||||
def devices(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The devices() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def global_shards(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The global_shards property was called on the JAX Tracer object {self}")
|
||||
|
||||
def is_deleted(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_deleted() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def is_fully_addressable(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_fully_addressable property was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The is_fully_replicated property was called on the JAX Tracer object {self}")
|
||||
|
||||
def on_device_size_in_bytes(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The on_device_size_in_bytes() method was called on the JAX Tracer object {self}")
|
||||
|
||||
@property
|
||||
def traceback(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The traceback property was called on the JAX Tracer object {self}")
|
||||
|
||||
def unsafe_buffer_pointer(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The unsafe_buffer_pointer() method was called on the JAX Tracer object {self}")
|
||||
|
||||
# these can be used to set up forwarding of properties and instance methods from
|
||||
# Tracer instances to the underlying avals
|
||||
aval_property = namedtuple("aval_property", ["fget"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user