mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jnp.take: fix annotation for fill_value
This commit is contained in:
parent
5d5ce1c919
commit
0ff0d7b95d
@ -115,6 +115,14 @@ class Array(abc.ABC):
|
||||
|
||||
Array.__module__ = "jax"
|
||||
|
||||
# StaticScalar is the Union of all scalar types that can be converted to
|
||||
# JAX arrays, and are possible to mark as static arguments.
|
||||
StaticScalar = Union[
|
||||
np.bool_, np.number, # NumPy scalar types
|
||||
bool, int, float, complex, # Python scalar types
|
||||
]
|
||||
StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars."
|
||||
|
||||
|
||||
# ArrayLike is a Union of all objects that can be implicitly converted to a
|
||||
# standard JAX array (i.e. not including future non-standard array types like
|
||||
@ -123,7 +131,6 @@ Array.__module__ = "jax"
|
||||
ArrayLike = Union[
|
||||
Array, # JAX array type
|
||||
np.ndarray, # NumPy array type
|
||||
np.bool_, np.number, # NumPy scalar types
|
||||
bool, int, float, complex, # Python scalar types
|
||||
StaticScalar, # valid scalars
|
||||
]
|
||||
ArrayLike.__doc__ = "Type annotation for JAX array-like objects."
|
||||
|
@ -217,11 +217,15 @@ class Array(abc.ABC):
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
||||
|
||||
StaticScalar = Union[
|
||||
np.bool_, np.number, # NumPy scalar types
|
||||
bool, int, float, complex, # Python scalar types
|
||||
]
|
||||
|
||||
ArrayLike = Union[
|
||||
Array, # JAX array type
|
||||
np.ndarray, # NumPy array type
|
||||
np.bool_, np.number, # NumPy scalar types
|
||||
bool, int, float, complex, # Python scalar types
|
||||
StaticScalar, # valid scalars
|
||||
]
|
||||
|
||||
|
||||
|
@ -68,8 +68,8 @@ from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike, DimSize, DuckTypedArray,
|
||||
DType, DTypeLike, Shape, DeprecatedArg
|
||||
Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray,
|
||||
DType, DTypeLike, Shape, StaticScalar,
|
||||
)
|
||||
from jax._src.util import (unzip2, subvals, safe_zip,
|
||||
ceil_of_ratio, partition_list,
|
||||
@ -5637,7 +5637,7 @@ def take(
|
||||
mode: str | None = None,
|
||||
unique_indices: bool = False,
|
||||
indices_are_sorted: bool = False,
|
||||
fill_value: ArrayLike | None = None,
|
||||
fill_value: StaticScalar | None = None,
|
||||
) -> Array:
|
||||
return _take(a, indices, None if axis is None else operator.index(axis), out,
|
||||
mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
|
||||
|
@ -34,6 +34,7 @@ import enum
|
||||
from jax._src.basearray import (
|
||||
Array as Array,
|
||||
ArrayLike as ArrayLike,
|
||||
StaticScalar as StaticScalar,
|
||||
)
|
||||
|
||||
DType = np.dtype
|
||||
|
@ -11,8 +11,8 @@ from jax._src.lax.slicing import GatherScatterMode
|
||||
from jax._src.lib import Device
|
||||
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike, DType, DTypeLike,
|
||||
DimSize, DuckTypedArray, Shape, DeprecatedArg
|
||||
Array, ArrayLike, DType, DTypeLike, DeprecatedArg,
|
||||
DimSize, DuckTypedArray, Shape, StaticScalar,
|
||||
)
|
||||
from jax.numpy import fft as fft, linalg as linalg
|
||||
from jax.sharding import Sharding as _Sharding
|
||||
@ -804,7 +804,7 @@ def take(
|
||||
mode: Optional[str] = ...,
|
||||
unique_indices: builtins.bool = ...,
|
||||
indices_are_sorted: builtins.bool = ...,
|
||||
fill_value: Optional[ArrayLike] = ...,
|
||||
fill_value: Optional[StaticScalar] = ...,
|
||||
) -> Array: ...
|
||||
def take_along_axis(
|
||||
arr: ArrayLike,
|
||||
|
Loading…
x
Reference in New Issue
Block a user