mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[typing] fix a few array type declarations
This commit is contained in:
parent
d0df18a76b
commit
d44b0389dd
@ -26,12 +26,20 @@ Traceback = Any
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
ndim: int
|
||||
size: int
|
||||
itemsize: int
|
||||
aval: Any
|
||||
|
||||
@property
|
||||
def dtype(self) -> np.dtype: ...
|
||||
|
||||
@property
|
||||
def ndim(self) -> int: ...
|
||||
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
|
||||
@property
|
||||
def itemsize(self) -> int: ...
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]: ...
|
||||
|
||||
@ -46,8 +54,7 @@ class Array(abc.ABC):
|
||||
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
|
||||
" Use jax.numpy.array, or jax.numpy.zeros instead.")
|
||||
|
||||
def __getitem__(self, key, indices_are_sorted=False,
|
||||
unique_indices=False) -> Array: ...
|
||||
def __getitem__(self, key) -> Array: ...
|
||||
def __setitem__(self, key, value) -> None: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __iter__(self) -> Any: ...
|
||||
|
@ -337,6 +337,9 @@ def _chunk_iter(x, size):
|
||||
if tail:
|
||||
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
|
||||
|
||||
def _getitem(self, item):
|
||||
return lax_numpy._rewriting_take(self, item)
|
||||
|
||||
# Syntactic sugar for scatter operations.
|
||||
class _IndexUpdateHelper:
|
||||
# Note: this docstring will appear as the docstring for the `at` property.
|
||||
@ -596,7 +599,7 @@ class _IndexUpdateRef:
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
_array_operators = {
|
||||
"getitem": lax_numpy._rewriting_take,
|
||||
"getitem": _getitem,
|
||||
"setitem": _unimplemented_setitem,
|
||||
"copy": _copy,
|
||||
"deepcopy": _deepcopy,
|
||||
|
Loading…
x
Reference in New Issue
Block a user