mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #13010 from jakevdp:annotate-lax-numpy
PiperOrigin-RevId: 484572239
This commit is contained in:
commit
f9e7629c3f
@ -47,7 +47,7 @@ class _IndexGrid(abc.ABC):
|
||||
sparse: bool
|
||||
op_name: str
|
||||
|
||||
def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array:
|
||||
def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Union[Array, List[Array]]:
|
||||
if isinstance(key, slice):
|
||||
return _make_1d_grid_from_slice(key, op_name=self.op_name)
|
||||
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key)
|
||||
|
@ -1938,7 +1938,7 @@ https://jax.readthedocs.io/en/latest/faq.html).
|
||||
|
||||
@_wraps(np.array, lax_description=_ARRAY_DOC)
|
||||
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
|
||||
order: str = "K", ndmin: int = 0) -> Array:
|
||||
order: Optional[str] = "K", ndmin: int = 0) -> Array:
|
||||
if order is not None and order != "K":
|
||||
raise NotImplementedError("Only implemented for order='K'")
|
||||
|
||||
@ -2012,7 +2012,7 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
|
||||
return out_array
|
||||
|
||||
|
||||
def _convert_to_array_if_dtype_fails(x):
|
||||
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
|
||||
try:
|
||||
dtypes.dtype(x)
|
||||
except TypeError:
|
||||
@ -2022,14 +2022,14 @@ def _convert_to_array_if_dtype_fails(x):
|
||||
|
||||
|
||||
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
|
||||
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
|
||||
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "asarray")
|
||||
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
|
||||
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
|
||||
|
||||
|
||||
@_wraps(np.copy, lax_description=_ARRAY_DOC)
|
||||
def copy(a, order=None):
|
||||
def copy(a: Any, order: Optional[str] = None) -> Array:
|
||||
return array(a, copy=True, order=order)
|
||||
|
||||
|
||||
@ -2116,13 +2116,13 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
|
||||
|
||||
@_wraps(np.array_equal)
|
||||
def array_equal(a1, a2, equal_nan=False):
|
||||
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
||||
try:
|
||||
a1, a2 = asarray(a1), asarray(a2)
|
||||
except Exception:
|
||||
return False
|
||||
return bool_(False)
|
||||
if shape(a1) != shape(a2):
|
||||
return False
|
||||
return bool_(False)
|
||||
eq = asarray(a1 == a2)
|
||||
if equal_nan:
|
||||
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
|
||||
@ -2130,16 +2130,16 @@ def array_equal(a1, a2, equal_nan=False):
|
||||
|
||||
|
||||
@_wraps(np.array_equiv)
|
||||
def array_equiv(a1, a2):
|
||||
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
try:
|
||||
a1, a2 = asarray(a1), asarray(a2)
|
||||
except Exception:
|
||||
return False
|
||||
return bool_(False)
|
||||
try:
|
||||
eq = equal(a1, a2)
|
||||
except ValueError:
|
||||
# shapes are not broadcastable
|
||||
return False
|
||||
return bool_(False)
|
||||
return all(eq)
|
||||
|
||||
|
||||
@ -2207,7 +2207,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s
|
||||
|
||||
|
||||
@_wraps(np.eye)
|
||||
def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0,
|
||||
def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0,
|
||||
dtype: Optional[DTypeLike] = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "eye")
|
||||
N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
|
||||
@ -2219,14 +2219,14 @@ def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0,
|
||||
|
||||
|
||||
@_wraps(np.identity)
|
||||
def identity(n: core.DimSize, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
def identity(n: DimSize, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "identity")
|
||||
return eye(n, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.arange)
|
||||
def arange(start: core.DimSize, stop: Optional[core.DimSize] = None,
|
||||
step: Optional[core.DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
def arange(start: DimSize, stop: Optional[DimSize] = None,
|
||||
step: Optional[DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "arange")
|
||||
require = partial(core.concrete_or_error, None)
|
||||
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
||||
@ -2404,7 +2404,8 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
|
||||
|
||||
|
||||
@_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
|
||||
def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
|
||||
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
|
||||
indexing: str = 'xy') -> List[Array]:
|
||||
_check_arraylike("meshgrid", *xi)
|
||||
args = [asarray(x) for x in xi]
|
||||
if not copy:
|
||||
@ -2426,17 +2427,16 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
|
||||
|
||||
@_wraps(np.i0)
|
||||
@jit
|
||||
def i0(x):
|
||||
x_orig = x
|
||||
x, = _promote_args_inexact("i0", x)
|
||||
if not issubdtype(x.dtype, np.floating):
|
||||
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x_orig)}")
|
||||
x = lax.abs(x)
|
||||
return lax.mul(lax.exp(x), lax.bessel_i0e(x))
|
||||
def i0(x: ArrayLike) -> Array:
|
||||
x_arr, = _promote_args_inexact("i0", x)
|
||||
if not issubdtype(x_arr.dtype, np.floating):
|
||||
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
|
||||
x_arr = lax.abs(x_arr)
|
||||
return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr))
|
||||
|
||||
|
||||
@_wraps(np.ix_)
|
||||
def ix_(*args):
|
||||
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
|
||||
_check_arraylike("ix", *args)
|
||||
n = len(args)
|
||||
output = []
|
||||
@ -2458,8 +2458,18 @@ def ix_(*args):
|
||||
return tuple(output)
|
||||
|
||||
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
sparse: Literal[False] = False) -> Array: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
*, sparse: Literal[True]) -> Tuple[Array, ...]: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ...
|
||||
@_wraps(np.indices)
|
||||
def indices(dimensions, dtype=int32, sparse=False):
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]:
|
||||
dimensions = tuple(
|
||||
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
|
||||
for d in dimensions)
|
||||
@ -2487,13 +2497,16 @@ will be repeated.
|
||||
|
||||
|
||||
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
|
||||
def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
|
||||
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
|
||||
total_repeat_length: Optional[int] = None) -> Array:
|
||||
_check_arraylike("repeat", a)
|
||||
core.is_special_dim_size(repeats) or _check_arraylike("repeat", repeats)
|
||||
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
else:
|
||||
a = asarray(a)
|
||||
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
|
||||
assert isinstance(axis, int) # to appease mypy
|
||||
@ -2512,28 +2525,28 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
|
||||
|
||||
# Fast path for when repeats is a scalar.
|
||||
if np.ndim(repeats) == 0 and ndim(a) != 0:
|
||||
input_shape = a.shape
|
||||
input_shape = shape(a)
|
||||
aux_axis = axis if axis < 0 else axis + 1
|
||||
a = expand_dims(a, aux_axis)
|
||||
reps = [1] * len(a.shape)
|
||||
reps: List[DimSize] = [1] * len(shape(a))
|
||||
reps[aux_axis] = repeats
|
||||
a = tile(a, reps)
|
||||
result_shape = list(input_shape)
|
||||
result_shape: List[DimSize] = list(input_shape)
|
||||
result_shape[axis] *= repeats
|
||||
return reshape(a, result_shape)
|
||||
|
||||
repeats = np.ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = np.broadcast_to(repeats, [a.shape[axis]])
|
||||
repeats = np.broadcast_to(repeats, [shape(a)[axis]])
|
||||
total_repeat_length = np.sum(repeats)
|
||||
else:
|
||||
repeats = ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = broadcast_to(repeats, [a.shape[axis]])
|
||||
repeats = broadcast_to(repeats, [shape(a)[axis]])
|
||||
|
||||
# Special case when a is a scalar.
|
||||
if ndim(a) == 0:
|
||||
if repeats.shape == (1,):
|
||||
if shape(repeats) == (1,):
|
||||
return full([total_repeat_length], a)
|
||||
else:
|
||||
raise ValueError('`repeat` with a scalar parameter `a` is only '
|
||||
@ -2541,13 +2554,13 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
|
||||
|
||||
# Special case if total_repeat_length is zero.
|
||||
if total_repeat_length == 0:
|
||||
result_shape = list(a.shape)
|
||||
result_shape = list(shape(a))
|
||||
result_shape[axis] = 0
|
||||
return reshape(array([], dtype=a.dtype), result_shape)
|
||||
return reshape(array([], dtype=_dtype(a)), result_shape)
|
||||
|
||||
# If repeats is on a zero sized axis, then return the array.
|
||||
if a.shape[axis] == 0:
|
||||
return a
|
||||
if shape(a)[axis] == 0:
|
||||
return asarray(a)
|
||||
|
||||
# This implementation of repeat avoid having to instantiate a large.
|
||||
# intermediate tensor.
|
||||
@ -2565,7 +2578,7 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
|
||||
|
||||
|
||||
@_wraps(np.tri)
|
||||
def tri(N, M=None, k=0, dtype=None):
|
||||
def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "tri")
|
||||
M = M if M is not None else N
|
||||
dtype = dtype or float32
|
||||
@ -2574,29 +2587,32 @@ def tri(N, M=None, k=0, dtype=None):
|
||||
|
||||
@_wraps(np.tril)
|
||||
@partial(jit, static_argnames=('k',))
|
||||
def tril(m, k=0):
|
||||
def tril(m: ArrayLike, k: int = 0) -> Array:
|
||||
_check_arraylike("tril", m)
|
||||
m_shape = shape(m)
|
||||
if len(m_shape) < 2:
|
||||
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
|
||||
mask = tri(*m_shape[-2:], k=k, dtype=bool)
|
||||
N, M = m_shape[-2:]
|
||||
mask = tri(N, M, k=k, dtype=bool)
|
||||
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
|
||||
|
||||
|
||||
@_wraps(np.triu, update_doc=False)
|
||||
@partial(jit, static_argnames=('k',))
|
||||
def triu(m, k=0):
|
||||
def triu(m: ArrayLike, k: int = 0) -> Array:
|
||||
_check_arraylike("triu", m)
|
||||
m_shape = shape(m)
|
||||
if len(m_shape) < 2:
|
||||
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
|
||||
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
|
||||
N, M = m_shape[-2:]
|
||||
mask = tri(N, M, k=k - 1, dtype=bool)
|
||||
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
|
||||
|
||||
|
||||
@_wraps(np.trace, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
|
||||
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
|
||||
def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
|
||||
dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
|
||||
_check_arraylike("trace", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
|
||||
@ -2636,13 +2652,15 @@ mask_indices = _wrap_indices_function(np.mask_indices)
|
||||
|
||||
|
||||
@_wraps(np.triu_indices_from)
|
||||
def triu_indices_from(arr, k=0):
|
||||
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])
|
||||
def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]:
|
||||
arr_shape = shape(arr)
|
||||
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||
|
||||
|
||||
@_wraps(np.tril_indices_from)
|
||||
def tril_indices_from(arr, k=0):
|
||||
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
|
||||
def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]:
|
||||
arr_shape = shape(arr)
|
||||
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||
|
||||
|
||||
@_wraps(np.diag_indices)
|
||||
@ -4250,7 +4268,7 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
|
||||
idx = tuple(idx) + colons
|
||||
return idx
|
||||
|
||||
def _static_idx(idx: slice, size: core.DimSize):
|
||||
def _static_idx(idx: slice, size: DimSize):
|
||||
"""Helper function to compute the static slice start/limit/stride values."""
|
||||
if isinstance(size, int):
|
||||
start, stop, step = idx.indices(size)
|
||||
@ -4318,11 +4336,11 @@ def kaiser(M: int, beta: ArrayLike) -> Array:
|
||||
return i0(beta * sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)
|
||||
|
||||
|
||||
def _gcd_cond_fn(xs):
|
||||
def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array:
|
||||
x1, x2 = xs
|
||||
return any(x2 != 0)
|
||||
|
||||
def _gcd_body_fn(xs):
|
||||
def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
|
||||
x1, x2 = xs
|
||||
x1, x2 = (where(x2 != 0, x2, x1),
|
||||
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
|
||||
@ -4330,7 +4348,7 @@ def _gcd_body_fn(xs):
|
||||
|
||||
@_wraps(np.gcd, module='numpy')
|
||||
@jit
|
||||
def gcd(x1, x2):
|
||||
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
_check_arraylike("gcd", x1, x2)
|
||||
x1, x2 = _promote_dtypes(x1, x2)
|
||||
if not issubdtype(_dtype(x1), integer):
|
||||
@ -4342,7 +4360,7 @@ def gcd(x1, x2):
|
||||
|
||||
@_wraps(np.lcm, module='numpy')
|
||||
@jit
|
||||
def lcm(x1, x2):
|
||||
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
_check_arraylike("lcm", x1, x2)
|
||||
x1, x2 = _promote_dtypes(x1, x2)
|
||||
if not issubdtype(_dtype(x1), integer):
|
||||
@ -4353,34 +4371,37 @@ def lcm(x1, x2):
|
||||
|
||||
|
||||
@_wraps(np.extract)
|
||||
def extract(condition, arr):
|
||||
def extract(condition: ArrayLike, arr: ArrayLike) -> Array:
|
||||
return compress(ravel(condition), ravel(arr))
|
||||
|
||||
|
||||
@_wraps(np.compress, skip_params=['out'])
|
||||
def compress(condition, a, axis: Optional[int] = None, out=None):
|
||||
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None,
|
||||
out: None = None) -> Array:
|
||||
_check_arraylike("compress", condition, a)
|
||||
condition_arr = asarray(condition).astype(bool)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
|
||||
if ndim(condition) != 1:
|
||||
if condition_arr.ndim != 1:
|
||||
raise ValueError("condition must be a 1D array")
|
||||
condition = asarray(condition).astype(bool)
|
||||
if axis is None:
|
||||
axis = 0
|
||||
a = ravel(a)
|
||||
arr = ravel(a)
|
||||
else:
|
||||
a = moveaxis(a, axis, 0)
|
||||
condition, extra = condition[:a.shape[0]], condition[a.shape[0]:]
|
||||
arr = moveaxis(a, axis, 0)
|
||||
condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:]
|
||||
if any(extra):
|
||||
raise ValueError("condition contains entries that are out of bounds")
|
||||
a = a[:condition.shape[0]]
|
||||
return moveaxis(a[condition], 0, axis)
|
||||
arr = arr[:condition_arr.shape[0]]
|
||||
return moveaxis(arr[condition_arr], 0, axis)
|
||||
|
||||
|
||||
@_wraps(np.cov)
|
||||
@partial(jit, static_argnames=('rowvar', 'bias', 'ddof'))
|
||||
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
aweights=None):
|
||||
def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
|
||||
bias: bool = False, ddof: Optional[int] = None,
|
||||
fweights: Optional[ArrayLike] = None,
|
||||
aweights: Optional[ArrayLike] = None) -> Array:
|
||||
if y is not None:
|
||||
m, y = _promote_args_inexact("cov", m, y)
|
||||
if y.ndim > 2:
|
||||
@ -4398,14 +4419,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
return array([]).reshape(0, 0)
|
||||
|
||||
if y is not None:
|
||||
y = atleast_2d(y)
|
||||
if not rowvar and y.shape[0] != 1:
|
||||
y = y.T
|
||||
X = concatenate((X, y), axis=0)
|
||||
y_arr = atleast_2d(y)
|
||||
if not rowvar and y_arr.shape[0] != 1:
|
||||
y_arr = y_arr.T
|
||||
X = concatenate((X, y_arr), axis=0)
|
||||
if ddof is None:
|
||||
ddof = 1 if bias == 0 else 0
|
||||
|
||||
w = None
|
||||
w: Optional[Array] = None
|
||||
if fweights is not None:
|
||||
_check_arraylike("cov", fweights)
|
||||
if ndim(fweights) > 1:
|
||||
@ -4424,7 +4445,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
raise RuntimeError("incompatible numbers of samples and aweights")
|
||||
# Ensure positive aweights: note that numpy raises an error for negative aweights.
|
||||
aweights = abs(aweights)
|
||||
w = aweights if w is None else w * aweights
|
||||
w = asarray(aweights) if w is None else w * asarray(aweights)
|
||||
|
||||
avg, w_sum = average(X, axis=1, weights=w, returned=True)
|
||||
w_sum = w_sum[0]
|
||||
@ -4445,7 +4466,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
|
||||
@_wraps(np.corrcoef)
|
||||
@partial(jit, static_argnames=('rowvar',))
|
||||
def corrcoef(x, y=None, rowvar=True):
|
||||
def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array:
|
||||
_check_arraylike("corrcoef", x)
|
||||
c = cov(x, y, rowvar)
|
||||
if len(shape(c)) == 0:
|
||||
@ -4467,9 +4488,9 @@ def corrcoef(x, y=None, rowvar=True):
|
||||
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
overwrite_input=False, method="linear", keepdims=False,
|
||||
interpolation=None):
|
||||
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False, method: str = "linear",
|
||||
keepdims: bool = False, interpolation: None = None) -> Array:
|
||||
_check_arraylike("quantile", a, q)
|
||||
if overwrite_input or out is not None:
|
||||
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
|
||||
@ -4478,14 +4499,14 @@ def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(a, q, axis, interpolation or method, keepdims, False)
|
||||
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False)
|
||||
|
||||
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False, method: str = "linear",
|
||||
keepdims: bool = False, interpolation: None = None) -> Array:
|
||||
_check_arraylike("nanquantile", a, q)
|
||||
if overwrite_input or out is not None:
|
||||
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
|
||||
@ -4494,9 +4515,10 @@ def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(a, q, axis, interpolation or method, keepdims, True)
|
||||
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True)
|
||||
|
||||
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
|
||||
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
|
||||
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
|
||||
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
|
||||
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
|
||||
"'midpoint', or 'nearest'")
|
||||
@ -4524,7 +4546,6 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
|
||||
do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
|
||||
touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
|
||||
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
|
||||
keepdim = tuple(keepdim)
|
||||
axis = _canonicalize_axis(-1, ndim(a))
|
||||
else:
|
||||
axis = _canonicalize_axis(axis, ndim(a))
|
||||
@ -4614,13 +4635,13 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
|
||||
raise ValueError(f"interpolation={interpolation!r} not recognized")
|
||||
if keepdims and keepdim:
|
||||
if q_ndim > 0:
|
||||
keepdim = (shape(q)[0],) + keepdim
|
||||
keepdim = [shape(q)[0], *keepdim]
|
||||
result = reshape(result, keepdim)
|
||||
return lax.convert_element_type(result, a.dtype)
|
||||
|
||||
|
||||
@partial(vectorize, excluded={0, 2, 3})
|
||||
def _searchsorted_via_scan(sorted_arr, query, side, dtype):
|
||||
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
|
||||
op = _sort_le_comparator if side == 'left' else _sort_lt_comparator
|
||||
def body_fun(_, state):
|
||||
low, high = state
|
||||
@ -4632,7 +4653,7 @@ def _searchsorted_via_scan(sorted_arr, query, side, dtype):
|
||||
return lax.fori_loop(0, n_levels, body_fun, init)[1]
|
||||
|
||||
|
||||
def _searchsorted_via_sort(sorted_arr, query, side, dtype):
|
||||
def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
|
||||
working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64
|
||||
def _rank(x):
|
||||
idx = lax.iota(working_dtype, len(x))
|
||||
@ -4653,7 +4674,8 @@ def _searchsorted_via_sort(sorted_arr, query, side, dtype):
|
||||
'sort' is often more performant on accelerator backends like GPU and TPU (particularly
|
||||
when ``v`` is very large)."""))
|
||||
@partial(jit, static_argnames=('side', 'sorter', 'method'))
|
||||
def searchsorted(a, v, side='left', sorter=None, *, method='scan'):
|
||||
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
|
||||
sorter: None = None, *, method: str = 'scan') -> Array:
|
||||
_check_arraylike("searchsorted", a, v)
|
||||
if side not in ['left', 'right']:
|
||||
raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
|
||||
@ -4670,22 +4692,23 @@ def searchsorted(a, v, side='left', sorter=None, *, method='scan'):
|
||||
if len(a) == 0:
|
||||
return zeros_like(v, dtype=dtype)
|
||||
impl = _searchsorted_via_scan if method == 'scan' else _searchsorted_via_sort
|
||||
return impl(a, v, side, dtype)
|
||||
return impl(asarray(a), asarray(v), side, dtype)
|
||||
|
||||
@_wraps(np.digitize)
|
||||
@partial(jit, static_argnames=('right',))
|
||||
def digitize(x, bins, right=False):
|
||||
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
|
||||
_check_arraylike("digitize", x, bins)
|
||||
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
|
||||
if ndim(bins) != 1:
|
||||
bins_arr = asarray(bins)
|
||||
if bins_arr.ndim != 1:
|
||||
raise ValueError(f"digitize: bins must be a 1-dimensional array; got bins={bins}")
|
||||
if len(bins) == 0:
|
||||
if bins_arr.shape[0] == 0:
|
||||
return zeros(x, dtype=dtypes.canonicalize_dtype(int_))
|
||||
side = 'right' if not right else 'left'
|
||||
return where(
|
||||
bins[-1] >= bins[0],
|
||||
searchsorted(bins, x, side=side),
|
||||
len(bins) - searchsorted(bins[::-1], x, side=side)
|
||||
bins_arr[-1] >= bins_arr[0],
|
||||
searchsorted(bins_arr, x, side=side),
|
||||
len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side)
|
||||
)
|
||||
|
||||
_PIECEWISE_DOC = """\
|
||||
@ -4695,9 +4718,10 @@ See the :func:`jax.lax.switch` documentation for more information.
|
||||
"""
|
||||
|
||||
@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
|
||||
def piecewise(x, condlist, funclist, *args, **kw):
|
||||
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
|
||||
funclist: List[Union[ArrayLike, Callable[..., Array]]],
|
||||
*args, **kw) -> Array:
|
||||
_check_arraylike("piecewise", x)
|
||||
condlist = array(condlist, dtype=bool_)
|
||||
nc, nf = len(condlist), len(funclist)
|
||||
if nf == nc + 1:
|
||||
funclist = funclist[-1:] + funclist[:-1]
|
||||
@ -4707,14 +4731,16 @@ def piecewise(x, condlist, funclist, *args, **kw):
|
||||
raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
|
||||
consts = {i: c for i, c in enumerate(funclist) if not callable(c)}
|
||||
funcs = {i: f for i, f in enumerate(funclist) if callable(f)}
|
||||
return _piecewise(x, condlist, consts,
|
||||
return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts,
|
||||
frozenset(funcs.items()), # dict is not hashable.
|
||||
*args, **kw)
|
||||
|
||||
@partial(jit, static_argnames=['funcs'])
|
||||
def _piecewise(x, condlist, consts, funcs, *args, **kw):
|
||||
funcs = dict(funcs)
|
||||
funclist = [consts.get(i, funcs.get(i)) for i in range(len(condlist) + 1)]
|
||||
def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
|
||||
funcs: FrozenSet[Tuple[int, Callable[..., Array]]],
|
||||
*args, **kw) -> Array:
|
||||
funcdict = dict(funcs)
|
||||
funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)]
|
||||
indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0)
|
||||
dtype = _dtype(x)
|
||||
def _call(f):
|
||||
@ -4728,9 +4754,10 @@ def _piecewise(x, condlist, consts, funcs, *args, **kw):
|
||||
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
def percentile(a: ArrayLike, q: ArrayLike,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False, method: str = "linear",
|
||||
keepdims: bool = False, interpolation: None = None) -> Array:
|
||||
_check_arraylike("percentile", a, q)
|
||||
q, = _promote_dtypes_inexact(q)
|
||||
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
@ -4739,9 +4766,10 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
def nanpercentile(a: ArrayLike, q: ArrayLike,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False, method: str = "linear",
|
||||
keepdims: bool = False, interpolation: None = None) -> Array:
|
||||
_check_arraylike("nanpercentile", a, q)
|
||||
q = true_divide(q, float32(100.0))
|
||||
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
@ -4750,23 +4778,25 @@ def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
|
||||
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
overwrite_input=False, keepdims=False):
|
||||
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False,
|
||||
keepdims: bool = False) -> Array:
|
||||
_check_arraylike("median", a)
|
||||
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
keepdims=keepdims, method='midpoint')
|
||||
|
||||
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
overwrite_input=False, keepdims=False):
|
||||
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out: None = None, overwrite_input: bool = False,
|
||||
keepdims: bool = False) -> Array:
|
||||
_check_arraylike("nanmedian", a)
|
||||
return nanquantile(a, 0.5, axis=axis, out=out,
|
||||
overwrite_input=overwrite_input, keepdims=keepdims,
|
||||
method='midpoint')
|
||||
|
||||
|
||||
def _astype(arr, dtype):
|
||||
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
|
||||
"""Copy the array and cast to a specified dtype.
|
||||
|
||||
This is implemeted via :func:`jax.lax.convert_element_type`, which may
|
||||
@ -4780,19 +4810,21 @@ def _astype(arr, dtype):
|
||||
return lax.convert_element_type(arr, dtype)
|
||||
|
||||
|
||||
def _nbytes(arr):
|
||||
def _nbytes(arr: ArrayLike) -> int:
|
||||
return size(arr) * _dtype(arr).itemsize
|
||||
|
||||
|
||||
def _itemsize(arr):
|
||||
def _itemsize(arr: ArrayLike) -> int:
|
||||
return _dtype(arr).itemsize
|
||||
|
||||
|
||||
def _clip(number, min=None, max=None, out=None): # noqa: F811
|
||||
def _clip(number: ArrayLike,
|
||||
min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, # noqa: F811
|
||||
out: None = None) -> Array:
|
||||
return clip(number, a_min=min, a_max=max, out=out)
|
||||
|
||||
|
||||
def _view(arr, dtype=None, type=None):
|
||||
def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "view")
|
||||
if type is not None:
|
||||
raise NotImplementedError("`type` argument of array.view()")
|
||||
@ -4814,7 +4846,8 @@ def _view(arr, dtype=None, type=None):
|
||||
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
|
||||
raise ValueError("When changing to a larger dtype, its size must be a divisor "
|
||||
"of the total size in bytes of the last axis of the array.")
|
||||
byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64}
|
||||
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
|
||||
32: np.dtype('uint32'), 64: np.dtype('uint64')}
|
||||
if nbits_in not in byte_dtypes:
|
||||
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
|
||||
if nbits_out not in byte_dtypes:
|
||||
@ -4883,15 +4916,15 @@ def _unimplemented_setitem(self, i, x):
|
||||
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
|
||||
raise TypeError(msg.format(type(self)))
|
||||
|
||||
def _operator_round(number, ndigits=None):
|
||||
def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array:
|
||||
out = round(number, decimals=ndigits or 0)
|
||||
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
|
||||
return out.astype(int) if ndigits is None else out
|
||||
|
||||
def _copy(self):
|
||||
def _copy(self: Array) -> Array:
|
||||
return self.copy()
|
||||
|
||||
def _deepcopy(self, memo):
|
||||
def _deepcopy(self: Array, memo: Any) -> Array:
|
||||
del memo # unused
|
||||
return self.copy()
|
||||
|
||||
@ -4970,22 +5003,23 @@ def __array_module__(self, types):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def _compress_method(a, condition, axis=None, out=None):
|
||||
def _compress_method(a: ArrayLike, condition: ArrayLike,
|
||||
axis: Optional[int] = None, out: None = None) -> Array:
|
||||
return compress(condition, a, axis, out)
|
||||
|
||||
|
||||
@core.stash_axis_env()
|
||||
@partial(jit, static_argnums=(1,2,3))
|
||||
def _multi_slice(arr,
|
||||
def _multi_slice(arr: ArrayLike,
|
||||
start_indices: Tuple[Tuple[int, ...]],
|
||||
limit_indices: Tuple[Tuple[int, ...]],
|
||||
removed_dims: Tuple[Tuple[int, ...]]):
|
||||
removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]:
|
||||
"""Extracts multiple slices from `arr`.
|
||||
|
||||
This is used to shard DeviceArray arguments to pmap. It's implemented as a
|
||||
DeviceArray method here to avoid circular imports.
|
||||
"""
|
||||
results = []
|
||||
results: List[Array] = []
|
||||
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
|
||||
sliced = lax.slice(arr, starts, limits)
|
||||
if removed:
|
||||
@ -4996,7 +5030,7 @@ def _multi_slice(arr,
|
||||
# The next two functions are related to iter(device_array), implemented here to
|
||||
# avoid circular imports.
|
||||
@jit
|
||||
def _unstack(x):
|
||||
def _unstack(x: Array) -> List[Array]:
|
||||
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
|
||||
setattr(device_array.DeviceArray, "_unstack", _unstack)
|
||||
setattr(ArrayImpl, '_unstack', _unstack)
|
||||
|
Loading…
x
Reference in New Issue
Block a user