jnp.take: fix annotation for fill_value

This commit is contained in:
Jake VanderPlas 2024-05-25 14:20:55 -07:00
parent 5d5ce1c919
commit 0ff0d7b95d
5 changed files with 22 additions and 10 deletions

View File

@ -115,6 +115,14 @@ class Array(abc.ABC):
Array.__module__ = "jax"
# StaticScalar is the Union of all scalar types that can be converted to
# JAX arrays, and are possible to mark as static arguments.
StaticScalar = Union[
np.bool_, np.number, # NumPy scalar types
bool, int, float, complex, # Python scalar types
]
StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars."
# ArrayLike is a Union of all objects that can be implicitly converted to a
# standard JAX array (i.e. not including future non-standard array types like
@ -123,7 +131,6 @@ Array.__module__ = "jax"
ArrayLike = Union[
Array, # JAX array type
np.ndarray, # NumPy array type
np.bool_, np.number, # NumPy scalar types
bool, int, float, complex, # Python scalar types
StaticScalar, # valid scalars
]
ArrayLike.__doc__ = "Type annotation for JAX array-like objects."

View File

@ -217,11 +217,15 @@ class Array(abc.ABC):
def unsafe_buffer_pointer(self) -> int: ...
StaticScalar = Union[
np.bool_, np.number, # NumPy scalar types
bool, int, float, complex, # Python scalar types
]
ArrayLike = Union[
Array, # JAX array type
np.ndarray, # NumPy array type
np.bool_, np.number, # NumPy scalar types
bool, int, float, complex, # Python scalar types
StaticScalar, # valid scalars
]

View File

@ -68,8 +68,8 @@ from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, StaticScalar,
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
@ -5637,7 +5637,7 @@ def take(
mode: str | None = None,
unique_indices: bool = False,
indices_are_sorted: bool = False,
fill_value: ArrayLike | None = None,
fill_value: StaticScalar | None = None,
) -> Array:
return _take(a, indices, None if axis is None else operator.index(axis), out,
mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,

View File

@ -34,6 +34,7 @@ import enum
from jax._src.basearray import (
Array as Array,
ArrayLike as ArrayLike,
StaticScalar as StaticScalar,
)
DType = np.dtype

View File

@ -11,8 +11,8 @@ from jax._src.lax.slicing import GatherScatterMode
from jax._src.lib import Device
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import (
Array, ArrayLike, DType, DTypeLike,
DimSize, DuckTypedArray, Shape, DeprecatedArg
Array, ArrayLike, DType, DTypeLike, DeprecatedArg,
DimSize, DuckTypedArray, Shape, StaticScalar,
)
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
@ -804,7 +804,7 @@ def take(
mode: Optional[str] = ...,
unique_indices: builtins.bool = ...,
indices_are_sorted: builtins.bool = ...,
fill_value: Optional[ArrayLike] = ...,
fill_value: Optional[StaticScalar] = ...,
) -> Array: ...
def take_along_axis(
arr: ArrayLike,