mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[typing] annotate next part of lax_numpy.py
This commit is contained in:
parent
cf6b5097d0
commit
2f27d516d7
@ -426,8 +426,9 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
|
||||
dtype = dtypes.to_inexact_dtype(arr.dtype)
|
||||
if _ndim(bins) == 1:
|
||||
return asarray(bins, dtype=dtype)
|
||||
bins = core.concrete_or_error(operator.index, bins,
|
||||
"bins argument of histogram_bin_edges")
|
||||
|
||||
bins_int = core.concrete_or_error(operator.index, bins,
|
||||
"bins argument of histogram_bin_edges")
|
||||
if range is None:
|
||||
range = [arr.min(), arr.max()]
|
||||
range = asarray(range, dtype=dtype)
|
||||
@ -436,7 +437,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
|
||||
range = (where(ptp(range) == 0, range[0] - 0.5, range[0]),
|
||||
where(ptp(range) == 0, range[1] + 0.5, range[1]))
|
||||
assert range is not None
|
||||
return linspace(range[0], range[1], bins + 1, dtype=dtype)
|
||||
return linspace(range[0], range[1], bins_int + 1, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.histogram)
|
||||
@ -865,13 +866,13 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
|
||||
shape = [shape]
|
||||
if _any(ndim(s) != 0 for s in shape):
|
||||
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
|
||||
out_indices = [None] * len(shape)
|
||||
out_indices = [0] * len(shape)
|
||||
for i, s in reversed(list(enumerate(shape))):
|
||||
indices_arr, out_indices[i] = divmod(indices_arr, s)
|
||||
oob_pos = indices_arr > 0
|
||||
oob_neg = indices_arr < -1
|
||||
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
|
||||
for s, i in zip(shape, out_indices))
|
||||
for s, i in safe_zip(shape, out_indices))
|
||||
|
||||
@_wraps(np.resize)
|
||||
@partial(jit, static_argnames=('new_shape',))
|
||||
@ -986,41 +987,66 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike
|
||||
|
||||
@_wraps(np.interp)
|
||||
@jit
|
||||
def interp(x, xp, fp, left=None, right=None, period=None):
|
||||
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
left: Optional[ArrayLike] = None,
|
||||
right: Optional[ArrayLike] = None,
|
||||
period: Optional[ArrayLike] = None) -> Array:
|
||||
_check_arraylike("interp", x, xp, fp)
|
||||
if shape(xp) != shape(fp) or ndim(xp) != 1:
|
||||
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
|
||||
x, xp = _promote_dtypes_inexact(x, xp)
|
||||
fp, = _promote_dtypes_inexact(fp)
|
||||
x_arr, xp_arr = _promote_dtypes_inexact(x, xp)
|
||||
fp_arr, = _promote_dtypes_inexact(fp)
|
||||
del x, xp, fp
|
||||
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
|
||||
raise ValueError("jnp.interp: complex x values not supported.")
|
||||
|
||||
if period is not None:
|
||||
if ndim(period) != 0:
|
||||
raise ValueError(f"period must be a scalar; got {period}")
|
||||
period = abs(period)
|
||||
x = x % period
|
||||
xp = xp % period
|
||||
xp, fp = lax.sort_key_val(xp, fp)
|
||||
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
|
||||
fp = concatenate([fp[-1:], fp, fp[:1]])
|
||||
x_arr = x_arr % period
|
||||
xp_arr = xp_arr % period
|
||||
xp_arr, fp_arr = lax.sort_key_val(xp_arr, fp_arr)
|
||||
xp_arr = concatenate([xp_arr[-1:] - period, xp_arr, xp_arr[:1] + period])
|
||||
fp_arr = concatenate([fp_arr[-1:], fp_arr, fp_arr[:1]])
|
||||
|
||||
i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
|
||||
df = fp[i] - fp[i - 1]
|
||||
dx = xp[i] - xp[i - 1]
|
||||
delta = x - xp[i - 1]
|
||||
i = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
|
||||
df = fp_arr[i] - fp_arr[i - 1]
|
||||
dx = xp_arr[i] - xp_arr[i - 1]
|
||||
delta = x_arr - xp_arr[i - 1]
|
||||
|
||||
epsilon = np.spacing(np.finfo(xp.dtype).eps)
|
||||
epsilon = np.spacing(np.finfo(xp_arr.dtype).eps)
|
||||
dx0 = lax.abs(dx) <= epsilon # Prevent NaN gradients when `dx` is small.
|
||||
f = where(dx0, fp[i - 1], fp[i - 1] + (delta / where(dx0, 1, dx)) * df)
|
||||
f = where(dx0, fp_arr[i - 1], fp_arr[i - 1] + (delta / where(dx0, 1, dx)) * df)
|
||||
|
||||
left_arr: ArrayLike = fp_arr[0] if left is None else left
|
||||
right_arr: ArrayLike = fp_arr[-1] if right is None else right
|
||||
|
||||
if period is None:
|
||||
f = where(x < xp[0], fp[0] if left is None else left, f)
|
||||
f = where(x > xp[-1], fp[-1] if right is None else right, f)
|
||||
f = where(x_arr < xp_arr[0], left_arr, f)
|
||||
f = where(x_arr > xp_arr[-1], right_arr, f)
|
||||
return f
|
||||
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
|
||||
size: Optional[int] = None,
|
||||
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
|
||||
) -> Tuple[Array, ...]: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
|
||||
size: Optional[int] = None,
|
||||
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
|
||||
) -> Array: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
|
||||
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
|
||||
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
|
||||
) -> Union[Array, Tuple[Array, ...]]: ...
|
||||
|
||||
@_wraps(np.where,
|
||||
lax_description=_dedent("""
|
||||
At present, JAX does not support JIT-compilation of the single-argument form
|
||||
@ -1036,7 +1062,10 @@ def interp(x, xp, fp, left=None, right=None, period=None):
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
def where(condition, x=None, y=None, *, size=None, fill_value=None):
|
||||
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
|
||||
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
|
||||
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
|
||||
) -> Union[Array, Tuple[Array, ...]]:
|
||||
if x is None and y is None:
|
||||
_check_arraylike("where", condition)
|
||||
return nonzero(condition, size=size, fill_value=fill_value)
|
||||
@ -1094,6 +1123,13 @@ def bincount(x, weights=None, minlength=0, *, length=None):
|
||||
raise ValueError("shape of weights must match shape of x.")
|
||||
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
|
||||
|
||||
@overload
|
||||
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
|
||||
|
||||
@overload
|
||||
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
|
||||
) -> Tuple[Union[int, core.Tracer], ...]: ...
|
||||
|
||||
@_wraps(getattr(np, "broadcast_shapes", None))
|
||||
def broadcast_shapes(*shapes):
|
||||
if not shapes:
|
||||
@ -1102,17 +1138,22 @@ def broadcast_shapes(*shapes):
|
||||
return lax.broadcast_shapes(*shapes)
|
||||
|
||||
|
||||
broadcast_arrays = _wraps(np.broadcast_arrays, lax_description="""\
|
||||
@_wraps(np.broadcast_arrays, lax_description="""\
|
||||
The JAX version does not necessarily return a view of the input.
|
||||
""")(_broadcast_arrays)
|
||||
""")
|
||||
def broadcast_arrays(*args: ArrayLike) -> List[Array]:
|
||||
return _broadcast_arrays(*args)
|
||||
|
||||
|
||||
broadcast_to = _wraps(np.broadcast_to, lax_description="""\
|
||||
@_wraps(np.broadcast_to, lax_description="""\
|
||||
The JAX version does not necessarily return a view of the input.
|
||||
""")(_broadcast_to)
|
||||
""")
|
||||
def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
|
||||
return _broadcast_to(array, shape)
|
||||
|
||||
|
||||
def _split(op, ary, indices_or_sections, axis=0):
|
||||
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
||||
axis: int = 0) -> List[Array]:
|
||||
_check_arraylike(op, ary)
|
||||
ary = asarray(ary)
|
||||
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
|
||||
@ -1133,7 +1174,7 @@ def _split(op, ary, indices_or_sections, axis=0):
|
||||
else:
|
||||
indices_or_sections = core.concrete_or_error(np.int64, indices_or_sections,
|
||||
f"in jax.numpy.{op} argument 1")
|
||||
part_size, r = _divmod(size, indices_or_sections)
|
||||
part_size, r = _divmod(size, indices_or_sections) # type: ignore[misc]
|
||||
if r == 0:
|
||||
split_indices = np.arange(indices_or_sections + 1,
|
||||
dtype=np.int64) * part_size
|
||||
@ -1150,12 +1191,12 @@ def _split(op, ary, indices_or_sections, axis=0):
|
||||
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
||||
|
||||
@_wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
|
||||
def split(ary, indices_or_sections, axis: int = 0):
|
||||
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
||||
return _split("split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
def _split_on_axis(op, axis):
|
||||
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
|
||||
@_wraps(getattr(np, op), update_doc=False)
|
||||
def f(ary, indices_or_sections):
|
||||
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
|
||||
return _split(op, ary, indices_or_sections, axis=axis)
|
||||
return f
|
||||
|
||||
@ -1164,12 +1205,13 @@ hsplit = _split_on_axis("hsplit", axis=1)
|
||||
dsplit = _split_on_axis("dsplit", axis=2)
|
||||
|
||||
@_wraps(np.array_split)
|
||||
def array_split(ary, indices_or_sections, axis: int = 0):
|
||||
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
@_wraps(np.clip, skip_params=['out'])
|
||||
@jit
|
||||
def clip(a, a_min=None, a_max=None, out=None):
|
||||
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
|
||||
a_max: Optional[ArrayLike] = None, out: None = None) -> Array:
|
||||
_check_arraylike("clip", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
|
||||
@ -1179,11 +1221,11 @@ def clip(a, a_min=None, a_max=None, out=None):
|
||||
a = maximum(a_min, a)
|
||||
if a_max is not None:
|
||||
a = minimum(a_max, a)
|
||||
return a
|
||||
return asarray(a)
|
||||
|
||||
@_wraps(np.around, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('decimals',))
|
||||
def round(a, decimals=0, out=None):
|
||||
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
|
||||
_check_arraylike("round", a)
|
||||
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
|
||||
if out is not None:
|
||||
@ -1193,9 +1235,9 @@ def round(a, decimals=0, out=None):
|
||||
if decimals < 0:
|
||||
raise NotImplementedError(
|
||||
"integer np.round not implemented for decimals < 0")
|
||||
return a # no-op on integer types
|
||||
return asarray(a) # no-op on integer types
|
||||
|
||||
def _round_float(x):
|
||||
def _round_float(x: ArrayLike) -> Array:
|
||||
if decimals == 0:
|
||||
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
|
||||
|
||||
@ -1219,7 +1261,7 @@ round_ = round
|
||||
|
||||
@_wraps(np.fix, skip_params=['out'])
|
||||
@jit
|
||||
def fix(x, out=None):
|
||||
def fix(x: ArrayLike, out: None = None) -> Array:
|
||||
_check_arraylike("fix", x)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
|
||||
@ -1229,7 +1271,9 @@ def fix(x, out=None):
|
||||
|
||||
@_wraps(np.nan_to_num)
|
||||
@jit
|
||||
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
|
||||
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
|
||||
posinf: Optional[ArrayLike] = None,
|
||||
neginf: Optional[ArrayLike] = None) -> Array:
|
||||
del copy
|
||||
_check_arraylike("nan_to_num", x)
|
||||
dtype = _dtype(x)
|
||||
@ -1240,15 +1284,16 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
|
||||
info = finfo(dtypes.canonicalize_dtype(dtype))
|
||||
posinf = info.max if posinf is None else posinf
|
||||
neginf = info.min if neginf is None else neginf
|
||||
x = where(isnan(x), array(nan, dtype=x.dtype), x)
|
||||
x = where(isposinf(x), array(posinf, dtype=x.dtype), x)
|
||||
x = where(isneginf(x), array(neginf, dtype=x.dtype), x)
|
||||
return x
|
||||
out = where(isnan(x), asarray(nan, dtype=dtype), x)
|
||||
out = where(isposinf(out), asarray(posinf, dtype=dtype), out)
|
||||
out = where(isneginf(out), asarray(neginf, dtype=dtype), out)
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(np.allclose)
|
||||
@partial(jit, static_argnames=('equal_nan',))
|
||||
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
||||
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
|
||||
atol: ArrayLike = 1e-08, equal_nan: bool = False):
|
||||
_check_arraylike("allclose", a, b)
|
||||
return all(isclose(a, b, rtol, atol, equal_nan))
|
||||
|
||||
@ -1269,31 +1314,34 @@ fill_value : array_like, optional
|
||||
"""
|
||||
|
||||
@_wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
|
||||
def nonzero(a, *, size=None, fill_value=None):
|
||||
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
|
||||
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
|
||||
) -> Tuple[Array, ...]:
|
||||
_check_arraylike("nonzero", a)
|
||||
a = atleast_1d(a)
|
||||
mask = a if a.dtype == bool else (a != 0)
|
||||
arr = atleast_1d(a)
|
||||
del a
|
||||
mask = arr if arr.dtype == bool else (arr != 0)
|
||||
if size is None:
|
||||
size = mask.sum()
|
||||
size = core.concrete_or_error(operator.index, size,
|
||||
"The size argument of jnp.nonzero must be statically specified "
|
||||
"to use jnp.nonzero within JAX transformations.")
|
||||
if a.size == 0 or size == 0:
|
||||
return tuple(zeros(size, int) for dim in a.shape)
|
||||
if arr.size == 0 or size == 0:
|
||||
return tuple(zeros(size, int) for dim in arr.shape)
|
||||
flat_indices = cumsum(bincount(cumsum(mask), length=size))
|
||||
strides = (np.cumprod(a.shape[::-1])[::-1] // a.shape).astype(int_)
|
||||
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))
|
||||
strides = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(int_)
|
||||
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape))
|
||||
if size is not None and fill_value is not None:
|
||||
if not isinstance(fill_value, tuple):
|
||||
fill_value = a.ndim * (fill_value,)
|
||||
if _shape(fill_value) != (a.ndim,):
|
||||
raise ValueError(f"fill_value must be a scalar or a tuple of length {a.ndim}; got {fill_value}")
|
||||
fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
|
||||
if _any(_shape(val) != () for val in fill_value_tup):
|
||||
raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
|
||||
fill_mask = arange(size) >= mask.sum()
|
||||
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value, out))
|
||||
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
|
||||
return out
|
||||
|
||||
@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
|
||||
def flatnonzero(a, *, size=None, fill_value=None):
|
||||
def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
|
||||
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array:
|
||||
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]
|
||||
|
||||
|
||||
|
@ -414,8 +414,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
|
||||
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
|
||||
# materialize the broadcast forms of scalar arguments.
|
||||
@api.jit
|
||||
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
|
||||
y: Optional[ArrayLike] = None) -> Array:
|
||||
def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments should "
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
|
@ -165,7 +165,7 @@ def _minimize_lbfgs(
|
||||
gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))
|
||||
|
||||
# replacements for next iteration
|
||||
status = 0
|
||||
status = jnp.array(0)
|
||||
status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
|
||||
status = jnp.where(state.ngev >= maxgrad, 3, status) # type: ignore
|
||||
status = jnp.where(state.nfev >= maxfun, 2, status) # type: ignore
|
||||
|
@ -5134,7 +5134,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
|
||||
unsupported_params = {
|
||||
'asarray': ['like'],
|
||||
'broadcast_to': ['subok', 'array'],
|
||||
'broadcast_to': ['subok'],
|
||||
'clip': ['kwargs'],
|
||||
'copy': ['subok'],
|
||||
'corrcoef': ['ddof', 'bias', 'dtype'],
|
||||
@ -5164,7 +5164,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
extra_params = {
|
||||
'broadcast_to': ['arr'],
|
||||
'einsum': ['precision'],
|
||||
'einsum_path': ['subscripts'],
|
||||
'take_along_axis': ['mode'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user