[typing] add missing attributes to jax.Array

This commit is contained in:
Jake VanderPlas 2023-01-11 10:00:59 -08:00
parent ced7332587
commit e1738af5b2

View File

@ -17,12 +17,14 @@ import numpy as np
from jax._src.sharding import Sharding
from jax._src.array import Shard
from jax._src.typing import ArrayLike
class Array(abc.ABC):
dtype: np.dtype
ndim: int
size: int
itemsize: int
aval: Any
@property
@ -112,6 +114,9 @@ class Array(abc.ABC):
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def astype(self, dtype) -> Array: ...
def broadcast(self, sizes: Sequence[int]) -> Array: ...
def broadcast_in_dim(self, shape: Sequence[Union[int, Any]],
broadcast_dimensions: Sequence[int]) -> Array: ...
def choose(self, choices, out=None, mode='raise') -> Array: ...
def clip(self, min=None, max=None, out=None) -> Array: ...
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
@ -150,6 +155,7 @@ class Array(abc.ABC):
def round(self, decimals=0, out=None) -> Array: ...
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ...
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...