mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[typing] add missing attributes to jax.Array
This commit is contained in:
parent
ced7332587
commit
e1738af5b2
@ -17,12 +17,14 @@ import numpy as np
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.array import Shard
|
||||
from jax._src.typing import ArrayLike
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
ndim: int
|
||||
size: int
|
||||
itemsize: int
|
||||
aval: Any
|
||||
|
||||
@property
|
||||
@ -112,6 +114,9 @@ class Array(abc.ABC):
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def astype(self, dtype) -> Array: ...
|
||||
def broadcast(self, sizes: Sequence[int]) -> Array: ...
|
||||
def broadcast_in_dim(self, shape: Sequence[Union[int, Any]],
|
||||
broadcast_dimensions: Sequence[int]) -> Array: ...
|
||||
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
||||
def clip(self, min=None, max=None, out=None) -> Array: ...
|
||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
|
||||
@ -150,6 +155,7 @@ class Array(abc.ABC):
|
||||
def round(self, decimals=0, out=None) -> Array: ...
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ...
|
||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
|
||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
|
Loading…
x
Reference in New Issue
Block a user