Merge pull request #23398 from damianoamatruda:fix-pytype-array

PiperOrigin-RevId: 670937582
This commit is contained in:
jax authors 2024-09-04 05:45:16 -07:00
commit 22be4eafca
2 changed files with 22 additions and 21 deletions

View File

@ -131,15 +131,17 @@ class Array(abc.ABC):
# np.ndarray methods:
def all(self, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
def any(self: Array, axis: Axis = None, out: None = None,
def any(self, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
def argmax(self: Array, axis: int | None = None, out: None = None,
def argmax(self, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array: ...
def argmin(self, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array: ...
def argpartition(self, kth, axis=-1, kind='introselect', order: None = None) -> Array: ...
def argsort(self, axis: int | None = -1, kind='quicksort', order: None = None) -> Array: ...
def astype(self, dtype: DTypeLike | None = None, max: ArrayLike | None = None) -> Array: ...
def argpartition(self, kth: int, axis: int = -1) -> Array: ...
def argsort(self, axis: int | None = -1, *, kind: None = None, order: None = None,
stable: bool = True, descending: bool = False) -> Array: ...
def astype(self, dtype: DTypeLike | None = None, copy: bool = False,
device: Device | Sharding | None = None) -> Array: ...
def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ...
def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ...
def compress(self, condition: ArrayLike,
@ -148,10 +150,10 @@ class Array(abc.ABC):
def conj(self) -> Array: ...
def conjugate(self) -> Array: ...
def copy(self) -> Array: ...
def cumprod(self, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def cumsum(self, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def cumprod(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def cumsum(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ...
def dot(self, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array: ...
@ -176,7 +178,7 @@ class Array(abc.ABC):
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array: ...
def ptp(self, axis: Axis = None, out: None = None,
def ptp(self, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array: ...
def ravel(self, order: str = 'C') -> Array: ...
@property
@ -189,7 +191,7 @@ class Array(abc.ABC):
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ...
def sort(self, axis: int | None = -1, *, kind: None = None,
order: None = None, stable: bool = True, descending: bool = False) -> Array: ...
def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ...
def squeeze(self, axis: Axis = None) -> Array: ...
def std(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
@ -212,7 +214,7 @@ class Array(abc.ABC):
def var(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
def view(self, dtype=None, type=None) -> Array: ...
def view(self, dtype: DTypeLike | None = None, type: None = None) -> Array: ...
# Even though we don't always support the NumPy array protocol, e.g., for
# tracer types, for type checking purposes we must declare support so we

View File

@ -90,13 +90,12 @@ def _argmin(self: Array, axis: int | None = None, out: None = None,
"""
return lax_numpy.argmin(self, axis=axis, out=out, keepdims=keepdims)
def _argpartition(self: Array, kth: int, axis: int = -1,
kind: str = 'introselect', order: None = None) -> Array:
def _argpartition(self: Array, kth: int, axis: int = -1) -> Array:
"""Return the indices that partially sort the array.
Refer to :func:`jax.numpy.argpartition` for the full documentation.
"""
return lax_numpy.argpartition(self, kth=kth, axis=axis, kind=kind, order=order)
return lax_numpy.argpartition(self, kth=kth, axis=axis)
def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: None = None,
stable: bool = True, descending: bool = False) -> Array:
@ -123,7 +122,7 @@ def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: s
Refer to :func:`jax.numpy.choose` for the full documentation.
"""
return lax_numpy.choose(self, choices=choices)
return lax_numpy.choose(self, choices=choices, out=out, mode=mode)
def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.
@ -163,16 +162,16 @@ def _copy(self: Array) -> Array:
"""
return lax_numpy.copy(self)
def _cumprod(self: Array, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Return the cumulative product of the array.
Refer to :func:`jax.numpy.cumprod` for the full documentation.
"""
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
def _cumsum(self: Array, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Return the cumulative sum of the array.
Refer to :func:`jax.numpy.cumsum` for the full documentation.
@ -337,7 +336,7 @@ def _sort(self: Array, axis: int | None = -1, *, kind: None = None,
return lax_numpy.sort(self, axis=axis, kind=kind, order=order,
stable=stable, descending=descending)
def _squeeze(self: Array, axis: int | Sequence[int] | None = None) -> Array:
def _squeeze(self: Array, axis: reductions.Axis = None) -> Array:
"""Remove one or more length-1 axes from array.
Refer to :func:`jax.numpy.squeeze` for full documentation.