[typing] fix a few array type declarations

This commit is contained in:
Jake VanderPlas 2023-09-12 13:21:48 -07:00
parent d0df18a76b
commit d44b0389dd
2 changed files with 17 additions and 7 deletions

View File

@ -26,12 +26,20 @@ Traceback = Any
class Array(abc.ABC):
dtype: np.dtype
ndim: int
size: int
itemsize: int
aval: Any
@property
def dtype(self) -> np.dtype: ...
@property
def ndim(self) -> int: ...
@property
def size(self) -> int: ...
@property
def itemsize(self) -> int: ...
@property
def shape(self) -> tuple[int, ...]: ...
@ -46,8 +54,7 @@ class Array(abc.ABC):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
def __getitem__(self, key, indices_are_sorted=False,
unique_indices=False) -> Array: ...
def __getitem__(self, key) -> Array: ...
def __setitem__(self, key, value) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Any: ...

View File

@ -337,6 +337,9 @@ def _chunk_iter(x, size):
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
def _getitem(self, item):
return lax_numpy._rewriting_take(self, item)
# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
# Note: this docstring will appear as the docstring for the `at` property.
@ -596,7 +599,7 @@ class _IndexUpdateRef:
unique_indices=unique_indices, mode=mode)
_array_operators = {
"getitem": lax_numpy._rewriting_take,
"getitem": _getitem,
"setitem": _unimplemented_setitem,
"copy": _copy,
"deepcopy": _deepcopy,