Merge pull request #14957 from jakevdp:leading-underscore

PiperOrigin-RevId: 516296407
commit 5768cdf796
jax authors, 2023-03-13 12:57:20 -07:00
52 changed files with 465 additions and 465 deletions
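
The rename is mechanical: the leading underscore is dropped from the shared helpers in jax._src.numpy.util (for example _check_arraylike -> check_arraylike, _promote_dtypes_inexact -> promote_dtypes_inexact), and every internal call site is updated to the new spelling. A minimal sketch of the pattern at a typical call site (the wrapper function below is made up for illustration):

# Old spelling (before this commit):
#   from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact
#   _check_arraylike("trapz", y)
#   y_arr, = _promote_dtypes_inexact(y)
# New spelling (after this commit):
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact

def _prepare_input(y):
    # check_arraylike raises TypeError for non-array inputs;
    # promote_dtypes_inexact returns its arguments promoted to an inexact dtype.
    check_arraylike("trapz", y)
    y_arr, = promote_dtypes_inexact(y)
    return y_arr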


@ -23,7 +23,7 @@ from jax import lax
from jax import numpy as jnp
from jax._src import core
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _promote_dtypes_inexact
from jax._src.numpy.util import promote_dtypes_inexact
def _fill_lanczos_kernel(radius, x):
@ -243,8 +243,8 @@ def scale_and_translate(image, shape: core.Shape,
assert isinstance(method, ResizeMethod)
kernel = _kernels[method]
image, = _promote_dtypes_inexact(image)
scale, translation = _promote_dtypes_inexact(scale, translation)
image, = promote_dtypes_inexact(image)
scale, translation = promote_dtypes_inexact(scale, translation)
return _scale_and_translate(image, shape, spatial_dims, scale, translation,
kernel, antialias, precision)
@ -281,7 +281,7 @@ def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
assert isinstance(method, ResizeMethod)
kernel = _kernels[method]
image, = _promote_dtypes_inexact(image)
image, = promote_dtypes_inexact(image)
# Skip dimensions that have scale=1 and translation=0, this is only possible
# since all of the current resize methods (kernels) are interpolating, so the
# output = input under an identity warp.
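
# Illustration (a small sketch, assuming the public jax.image.resize wrapper):
# because _resize calls promote_dtypes_inexact(image), an integer image is
# promoted to a floating dtype before the interpolation kernel runs.
import jax.numpy as jnp
from jax import image

img = jnp.arange(16, dtype=jnp.int32).reshape(4, 4)
out = image.resize(img, shape=(8, 8), method="linear")
# out.dtype is an inexact dtype (float32 under the default configuration), not int32.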


@ -30,7 +30,7 @@ from jax._src.interpreters import batching
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
__all__ = [
"fft",
@ -61,9 +61,9 @@ def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int])
if typ == xla_client.FftType.RFFT:
if np.iscomplexobj(x):
raise ValueError("only real valued inputs supported for rfft")
x, = _promote_dtypes_inexact(x)
x, = promote_dtypes_inexact(x)
else:
x, = _promote_dtypes_complex(x)
x, = promote_dtypes_complex(x)
if len(fft_lengths) == 0:
# XLA FFT doesn't support 0-rank.
return x
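
# Illustration (sketch): how the promotion above surfaces through jnp.fft.
# Integer or real inputs to rfft are promoted to an inexact dtype, fft promotes
# its input to a complex dtype, and complex inputs to rfft raise the ValueError above.
import jax.numpy as jnp

x = jnp.arange(8)            # integer input
jnp.fft.fft(x).dtype         # complex64 under the default configuration
jnp.fft.rfft(x).dtype        # complex64 output; the real input is promoted to float32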


@ -21,7 +21,7 @@ from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
@ -43,7 +43,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
s: Optional[Shape], axes: Optional[Sequence[int]],
norm: Optional[str]) -> Array:
full_name = "jax.numpy.fft." + func_name
_check_arraylike(full_name, a)
check_arraylike(full_name, a)
arr = jnp.asarray(a)
if s is not None:
@ -293,7 +293,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
@_wraps(np.fft.fftshift)
def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
_check_arraylike("fftshift", x)
check_arraylike("fftshift", x)
x = jnp.asarray(x)
shift: Union[int, Sequence[int]]
if axes is None:
@ -309,7 +309,7 @@ def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Arra
@_wraps(np.fft.ifftshift)
def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
_check_arraylike("ifftshift", x)
check_arraylike("ifftshift", x)
x = jnp.asarray(x)
shift: Union[int, Sequence[int]]
if axes is None:
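
# Illustration (sketch): fftshift moves the zero-frequency term to the center of
# the spectrum and ifftshift inverts it; both now validate input via check_arraylike.
import jax.numpy as jnp

freqs = jnp.fft.fftfreq(8)
centered = jnp.fft.fftshift(freqs)
restored = jnp.fft.ifftshift(centered)   # equal to freqs again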


@ -17,7 +17,7 @@ from typing import Any, Iterable, List, Tuple, Union
import jax
from jax._src import core
from jax._src.numpy.util import _promote_dtypes
from jax._src.numpy.util import promote_dtypes
from jax._src.numpy.lax_numpy import (
arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
)
@ -54,7 +54,7 @@ class _IndexGrid(abc.ABC):
return _make_1d_grid_from_slice(key, op_name=self.op_name)
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key)
with jax.numpy_dtype_promotion('standard'):
output = _promote_dtypes(*output)
output = promote_dtypes(*output)
output_arr = meshgrid(*output, indexing='ij', sparse=self.sparse)
if self.sparse:
return output_arr
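
# Illustration (sketch): _IndexGrid backs jnp.mgrid and jnp.ogrid, and the per-axis
# slices are promoted to a common dtype with promote_dtypes before meshgrid is called.
import jax.numpy as jnp

dense = jnp.mgrid[0:3, 0:2]    # one stacked array of shape (2, 3, 2)
sparse = jnp.ogrid[0:3, 0:2]   # a list of two broadcastable arrays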


@ -334,11 +334,11 @@ def result_type(*args: ArrayLike) -> DType:
@partial(jit, static_argnames=('axis',))
def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array:
if x is None:
util._check_arraylike('trapz', y)
y_arr, = util._promote_dtypes_inexact(y)
util.check_arraylike('trapz', y)
y_arr, = util.promote_dtypes_inexact(y)
else:
util._check_arraylike('trapz', y, x)
y_arr, x_arr = util._promote_dtypes_inexact(y, x)
util.check_arraylike('trapz', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx = diff(x_arr)
else:
@ -350,7 +350,7 @@ def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis
@util._wraps(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:
util._check_arraylike('trunc', x)
util.check_arraylike('trunc', x)
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))
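
# Illustration (sketch): trapz promotes y (and x, when given) to an inexact dtype
# before integrating, and trunc rounds toward zero.
import jax.numpy as jnp

jnp.trapz(jnp.array([0, 1, 4, 9]))     # 9.5 with the default spacing dx=1.0
jnp.trunc(jnp.array([-1.7, 2.3]))      # [-1., 2.]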
@ -358,7 +358,7 @@ def trunc(x: ArrayLike) -> Array:
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array:
if ndim(x) != 1 or ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
x, y = util._promote_dtypes_inexact(x, y)
x, y = util.promote_dtypes_inexact(x, y)
if len(x) == 0 or len(y) == 0:
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
@ -391,7 +391,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> A
@partial(jit, static_argnames=('mode', 'precision'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
precision: PrecisionLike = None) -> Array:
util._check_arraylike("convolve", a, v)
util.check_arraylike("convolve", a, v)
return _conv(asarray(a), asarray(v), mode, 'convolve', precision)
@ -399,7 +399,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
@partial(jit, static_argnames=('mode', 'precision'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
precision: PrecisionLike = None) -> Array:
util._check_arraylike("correlate", a, v)
util.check_arraylike("correlate", a, v)
return _conv(asarray(a), asarray(v), mode, 'correlate', precision)
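
# Illustration (sketch): convolve and correlate accept only 1-D inputs, which
# _conv promotes to a common inexact dtype.
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.0, 1.0, 0.5])
jnp.convolve(a, v)      # full mode by default: [0., 1., 2.5, 4., 1.5]
jnp.correlate(a, v)     # 'valid' mode by default, as in the signature above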
@ -410,7 +410,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
del weights # unused, because string bins is not supported.
if isinstance(bins, str):
raise NotImplementedError("string values for `bins` not implemented.")
util._check_arraylike("histogram_bin_edges", a, bins)
util.check_arraylike("histogram_bin_edges", a, bins)
arr = ravel(a)
dtype = dtypes.to_inexact_dtype(arr.dtype)
if _ndim(bins) == 1:
@ -435,14 +435,14 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array]:
if weights is None:
util._check_arraylike("histogram", a, bins)
a = ravel(*util._promote_dtypes_inexact(a))
util.check_arraylike("histogram", a, bins)
a = ravel(*util.promote_dtypes_inexact(a))
weights = ones_like(a)
else:
util._check_arraylike("histogram", a, bins, weights)
util.check_arraylike("histogram", a, bins, weights)
if shape(a) != shape(weights):
raise ValueError("weights should have the same shape as a.")
a, weights = map(ravel, util._promote_dtypes_inexact(a, weights))
a, weights = map(ravel, util.promote_dtypes_inexact(a, weights))
bin_edges = histogram_bin_edges(a, bins, range, weights)
bin_idx = searchsorted(bin_edges, a, side='right')
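
# Illustration (sketch): without weights the samples are promoted to an inexact
# dtype and unit weights are used; with weights, samples and weights are promoted
# together and must have the same shape.
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 2.5, 3.0])
counts, edges = jnp.histogram(data, bins=2)   # counts of [1, 3] over edges [1, 2, 3]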
@ -458,7 +458,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLik
range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array, Array]:
util._check_arraylike("histogram2d", x, y)
util.check_arraylike("histogram2d", x, y)
try:
N = len(bins) # type: ignore[arg-type]
except TypeError:
@ -478,13 +478,13 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, List[Array]]:
if weights is None:
util._check_arraylike("histogramdd", sample)
sample, = util._promote_dtypes_inexact(sample)
util.check_arraylike("histogramdd", sample)
sample, = util.promote_dtypes_inexact(sample)
else:
util._check_arraylike("histogramdd", sample, weights)
util.check_arraylike("histogramdd", sample, weights)
if shape(weights) != shape(sample)[:1]:
raise ValueError("should have one weight for each sample.")
sample, weights = util._promote_dtypes_inexact(sample, weights)
sample, weights = util.promote_dtypes_inexact(sample, weights)
N, D = shape(sample)
if range is not None and (
@ -538,7 +538,7 @@ view of the input.
@util._wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array:
util._stackable(a) or util._check_arraylike("transpose", a)
util._stackable(a) or util.check_arraylike("transpose", a)
axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
return lax.transpose(a, axes_)
@ -547,7 +547,7 @@ def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array:
@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
util._check_arraylike("rot90", m)
util.check_arraylike("rot90", m)
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, ndim(m))
ax2 = _canonicalize_axis(ax2, ndim(m))
@ -569,7 +569,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
util._check_arraylike("flip", m)
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
@partial(jit, static_argnames=('axis',))
@ -582,13 +582,13 @@ def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array
@util._wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m: ArrayLike) -> Array:
util._check_arraylike("fliplr", m)
util.check_arraylike("fliplr", m)
return _flip(asarray(m), 1)
@util._wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m: ArrayLike) -> Array:
util._check_arraylike("flipud", m)
util.check_arraylike("flipud", m)
return _flip(asarray(m), 0)
@util._wraps(np.iscomplex)
@ -623,7 +623,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
prepend: Optional[ArrayLike] = None,
append: Optional[ArrayLike] = None) -> Array:
util._check_arraylike("diff", a)
util.check_arraylike("diff", a)
arr = asarray(a)
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")
@ -639,7 +639,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
combined: List[Array] = []
if prepend is not None:
util._check_arraylike("diff", prepend)
util.check_arraylike("diff", prepend)
if isscalar(prepend):
shape = list(arr.shape)
shape[axis] = 1
@ -649,7 +649,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
combined.append(arr)
if append is not None:
util._check_arraylike("diff", append)
util.check_arraylike("diff", append)
if isscalar(append):
shape = list(arr.shape)
shape[axis] = 1
@ -682,14 +682,14 @@ loses precision.
@jit
def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None,
to_begin: Optional[ArrayLike] = None) -> Array:
util._check_arraylike("ediff1d", ary)
util.check_arraylike("ediff1d", ary)
arr = ravel(ary)
result = lax.sub(arr[1:], arr[:-1])
if to_begin is not None:
util._check_arraylike("ediff1d", to_begin)
util.check_arraylike("ediff1d", to_begin)
result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result))
if to_end is not None:
util._check_arraylike("ediff1d", to_end)
util.check_arraylike("ediff1d", to_end)
result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype))))
return result
@ -701,7 +701,7 @@ def gradient(f: ArrayLike, *varargs: ArrayLike,
edge_order: Optional[int] = None) -> Union[Array, List[Array]]:
if edge_order is not None:
raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
a, *spacing = util._promote_args_inexact("gradient", f, *varargs)
a, *spacing = util.promote_args_inexact("gradient", f, *varargs)
def gradient_along_axis(a, h, axis):
sliced = partial(lax.slice_in_dim, a, axis=axis)
@ -750,7 +750,7 @@ def isrealobj(x: Any) -> bool:
@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: Union[DimSize, Shape], order: str = "C") -> Array:
util._stackable(a) or util._check_arraylike("reshape", a)
util._stackable(a) or util.check_arraylike("reshape", a)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
@ -817,7 +817,7 @@ def _transpose(a: Array, *args: Any) -> Array:
@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a: ArrayLike, order: str = "C") -> Array:
util._stackable(a) or util._check_arraylike("ravel", a)
util._stackable(a) or util.check_arraylike("ravel", a)
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
return reshape(a, (size(a),), order)
@ -828,7 +828,7 @@ def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...],
mode: str = 'raise', order: str = 'C') -> Array:
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
util._check_arraylike("ravel_multi_index", *multi_index)
util.check_arraylike("ravel_multi_index", *multi_index)
multi_index_arr = [asarray(i) for i in multi_index]
for index in multi_index_arr:
if mode == 'raise':
@ -868,7 +868,7 @@ and out-of-bounds indices are clipped into the valid range.
@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
util._check_arraylike("unravel_index", indices)
util.check_arraylike("unravel_index", indices)
indices_arr = asarray(indices)
# Note: we do not convert shape to an array, because it may be passed as a
# tuple of weakly-typed values, and asarray() would strip these weak types.
@ -889,7 +889,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
@util._wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
util._check_arraylike("resize", a)
util.check_arraylike("resize", a)
new_shape = _ensure_index_tuple(new_shape)
if _any(dim_length < 0 for dim_length in new_shape):
@ -908,7 +908,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
util._check_arraylike("squeeze", a)
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
@partial(jit, static_argnames=('axis',), inline=True)
@ -924,7 +924,7 @@ def _squeeze(a: Array, axis: Tuple[int]) -> Array:
@util._wraps(np.expand_dims)
def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array:
util._stackable(a) or util._check_arraylike("expand_dims", a)
util._stackable(a) or util.check_arraylike("expand_dims", a)
axis = _ensure_index_tuple(axis)
if hasattr(a, "expand_dims"):
return a.expand_dims(axis) # type: ignore
@ -934,7 +934,7 @@ def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array:
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
util._check_arraylike("swapaxes", a)
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
return lax.transpose(a, list(perm))
@ -943,7 +943,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
@util._wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC)
def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]]) -> Array:
util._check_arraylike("moveaxis", a)
util.check_arraylike("moveaxis", a)
return _moveaxis(asarray(a), _ensure_index_tuple(source),
_ensure_index_tuple(destination))
@ -964,7 +964,7 @@ def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
equal_nan: bool = False) -> Array:
a, b = util._promote_args("isclose", a, b)
a, b = util.promote_args("isclose", a, b)
dtype = _dtype(a)
if issubdtype(dtype, inexact):
if issubdtype(dtype, complexfloating):
@ -1006,11 +1006,11 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: Optional[ArrayLike] = None,
right: Optional[ArrayLike] = None,
period: Optional[ArrayLike] = None) -> Array:
util._check_arraylike("interp", x, xp, fp)
util.check_arraylike("interp", x, xp, fp)
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x_arr, xp_arr = util._promote_dtypes_inexact(x, xp)
fp_arr, = util._promote_dtypes_inexact(fp)
x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
fp_arr, = util.promote_dtypes_inexact(fp)
del x, xp, fp
if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
@ -1091,10 +1091,10 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]:
if x is None and y is None:
util._check_arraylike("where", condition)
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
else:
util._check_arraylike("where", condition, x, y)
util.check_arraylike("where", condition, x, y)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
return util._where(condition, x, y)
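
# Illustration (sketch): the one-argument form of where is equivalent to nonzero,
# while the three-argument form is an elementwise select.
import jax.numpy as jnp

x = jnp.array([1, -2, 3])
jnp.where(x > 0)          # (Array([0, 2], dtype=int32),)
jnp.where(x > 0, x, 0)    # Array([1, 0, 3], dtype=int32)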
@ -1107,7 +1107,7 @@ def select(condlist, choicelist, default=0):
raise ValueError(msg.format(len(condlist), len(choicelist)))
if len(condlist) == 0:
raise ValueError("condlist must be non-empty")
choices = util._promote_dtypes(default, *choicelist)
choices = util.promote_dtypes(default, *choicelist)
choicelist = choices[1:]
output = choices[0]
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
@ -1126,7 +1126,7 @@ negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None,
minlength: int = 0, *, length: Optional[int] = None) -> Array:
util._check_arraylike("bincount", x)
util.check_arraylike("bincount", x)
if not issubdtype(_dtype(x), integer):
raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
if ndim(x) != 1:
@ -1178,7 +1178,7 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
axis: int = 0) -> List[Array]:
util._check_arraylike(op, ary)
util.check_arraylike(op, ary)
ary = asarray(ary)
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
size = ary.shape[axis]
@ -1223,7 +1223,7 @@ def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayL
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
# for 1-D array, hsplit becomes vsplit
nonlocal axis
util._check_arraylike(op, ary)
util.check_arraylike(op, ary)
a = asarray(ary)
if axis == 1 and len(a.shape) == 1:
axis = 0
@ -1242,7 +1242,7 @@ def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis
@jit
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
a_max: Optional[ArrayLike] = None, out: None = None) -> Array:
util._check_arraylike("clip", a)
util.check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
@ -1256,7 +1256,7 @@ def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
@util._wraps(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
util._check_arraylike("round", a)
util.check_arraylike("round", a)
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
@ -1292,7 +1292,7 @@ round_ = round
@util._wraps(np.fix, skip_params=['out'])
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
util._check_arraylike("fix", x)
util.check_arraylike("fix", x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
zero = _lax_const(x, 0)
@ -1305,7 +1305,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
posinf: Optional[ArrayLike] = None,
neginf: Optional[ArrayLike] = None) -> Array:
del copy
util._check_arraylike("nan_to_num", x)
util.check_arraylike("nan_to_num", x)
dtype = _dtype(x)
if not issubdtype(dtype, inexact):
return asarray(x)
@ -1326,7 +1326,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
@partial(jit, static_argnames=('equal_nan',))
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array:
util._check_arraylike("allclose", a, b)
util.check_arraylike("allclose", a, b)
return reductions.all(isclose(a, b, rtol, atol, equal_nan))
@ -1349,7 +1349,7 @@ fill_value : array_like, optional
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]:
util._check_arraylike("nonzero", a)
util.check_arraylike("nonzero", a)
arr = atleast_1d(a)
del a
mask = arr if arr.dtype == bool else (arr != 0)
@ -1382,12 +1382,12 @@ def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
util._check_arraylike("unwrap", p)
util.check_arraylike("unwrap", p)
p = asarray(p)
if issubdtype(p.dtype, np.complexfloating):
raise ValueError("jnp.unwrap does not support complex inputs.")
if p.shape[axis] == 0:
return util._promote_dtypes_inexact(p)[0]
return util.promote_dtypes_inexact(p)[0]
if discont is None:
discont = period / 2
interval = period / 2
@ -1714,7 +1714,7 @@ the modified array. This is because Jax arrays are immutable.
""")
def pad(array: ArrayLike, pad_width: PadValueLike[int],
mode: Union[str, Callable[..., Any]] = "constant", **kwargs) -> Array:
util._check_arraylike("pad", array)
util.check_arraylike("pad", array)
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
if pad_width and not _all(core.is_dim(p[0]) and core.is_dim(p[1])
for p in pad_width):
@ -1764,7 +1764,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
else:
util._stackable(*arrays) or util._check_arraylike("stack", *arrays)
util._stackable(*arrays) or util.check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
@ -1776,7 +1776,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
@util._wraps(np.tile)
def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
util._stackable(A) or util._check_arraylike("tile", A)
util._stackable(A) or util.check_arraylike("tile", A)
try:
iter(reps) # type: ignore[arg-type]
except TypeError:
@ -1811,7 +1811,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(arrays, (np.ndarray, Array)):
return _concatenate_array(arrays, axis, dtype=dtype)
util._stackable(*arrays) or util._check_arraylike("concatenate", *arrays)
util._stackable(*arrays) or util.check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
@ -1822,7 +1822,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr]
axis = _canonicalize_axis(axis, ndim(arrays[0]))
if dtype is None:
arrays_out = util._promote_dtypes(*arrays)
arrays_out = util.promote_dtypes(*arrays)
else:
arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a
@ -1882,7 +1882,7 @@ def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util._check_arraylike('choose', a, *choices)
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
raise ValueError("`a` array must be integer typed")
N = len(choices)
@ -2072,14 +2072,14 @@ def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = No
@util._wraps(np.copy, lax_description=_ARRAY_DOC)
def copy(a: ArrayLike, order: Optional[str] = None) -> Array:
util._check_arraylike("copy", a)
util.check_arraylike("copy", a)
return array(a, copy=True, order=order)
@util._wraps(np.zeros_like)
def zeros_like(a: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Any = None) -> Array:
util._check_arraylike("zeros_like", a)
util.check_arraylike("zeros_like", a)
dtypes.check_user_dtype_supported(dtype, "zeros_like")
if shape is not None:
shape = canonicalize_shape(shape)
@ -2089,7 +2089,7 @@ def zeros_like(a: ArrayLike, dtype: Optional[DTypeLike] = None,
@util._wraps(np.ones_like)
def ones_like(a: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Any = None) -> Array:
util._check_arraylike("ones_like", a)
util.check_arraylike("ones_like", a)
dtypes.check_user_dtype_supported(dtype, "ones_like")
if shape is not None:
shape = canonicalize_shape(shape)
@ -2101,7 +2101,7 @@ Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty_like(prototype: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Any = None) -> Array:
util._check_arraylike("empty_like", prototype)
util.check_arraylike("empty_like", prototype)
dtypes.check_user_dtype_supported(dtype, "empty_like")
return zeros_like(prototype, dtype=dtype, shape=shape)
@ -2110,7 +2110,7 @@ def empty_like(prototype: ArrayLike, dtype: Optional[DTypeLike] = None,
def full(shape: Any, fill_value: ArrayLike,
dtype: Optional[DTypeLike] = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "full")
util._check_arraylike("full", fill_value)
util.check_arraylike("full", fill_value)
if ndim(fill_value) == 0:
shape = canonicalize_shape(shape)
return lax.full(shape, fill_value, dtype)
@ -2122,7 +2122,7 @@ def full(shape: Any, fill_value: ArrayLike,
def full_like(a: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Any = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "full_like")
util._check_arraylike("full_like", a, fill_value)
util.check_arraylike("full_like", a, fill_value)
if shape is not None:
shape = canonicalize_shape(shape)
if ndim(fill_value) == 0:
@ -2340,7 +2340,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
dtypes.check_user_dtype_supported(dtype, "linspace")
if num < 0:
raise ValueError(f"Number of samples, {num}, must be non-negative.")
util._check_arraylike("linspace", start, stop)
util.check_arraylike("linspace", start, stop)
if dtype is None:
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
@ -2407,7 +2407,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
dtype = _jnp_dtype(dtype)
computation_dtype = dtypes.to_inexact_dtype(dtype)
util._check_arraylike("logspace", start, stop)
util.check_arraylike("logspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
lin = linspace(start, stop, num,
@ -2431,7 +2431,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
dtype = _jnp_dtype(dtype)
computation_dtype = dtypes.to_inexact_dtype(dtype)
util._check_arraylike("geomspace", start, stop)
util.check_arraylike("geomspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
@ -2449,7 +2449,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
@util._wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> List[Array]:
util._check_arraylike("meshgrid", *xi)
util.check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
raise ValueError("jax.numpy.meshgrid only supports copy=True")
@ -2471,7 +2471,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
@util._wraps(np.i0)
@jit
def i0(x: ArrayLike) -> Array:
x_arr, = util._promote_args_inexact("i0", x)
x_arr, = util.promote_args_inexact("i0", x)
if not issubdtype(x_arr.dtype, np.floating):
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
x_arr = lax.abs(x_arr)
@ -2480,7 +2480,7 @@ def i0(x: ArrayLike) -> Array:
@util._wraps(np.ix_)
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
util._check_arraylike("ix", *args)
util.check_arraylike("ix", *args)
n = len(args)
output = []
for i, a in enumerate(args):
@ -2542,8 +2542,8 @@ will be repeated.
@util._wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
total_repeat_length: Optional[int] = None) -> Array:
util._check_arraylike("repeat", a)
core.is_special_dim_size(repeats) or util._check_arraylike("repeat", repeats)
util.check_arraylike("repeat", a)
core.is_special_dim_size(repeats) or util.check_arraylike("repeat", repeats)
if axis is None:
a = ravel(a)
@ -2631,7 +2631,7 @@ def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) ->
@util._wraps(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m: ArrayLike, k: int = 0) -> Array:
util._check_arraylike("tril", m)
util.check_arraylike("tril", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
@ -2643,7 +2643,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
@util._wraps(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m: ArrayLike, k: int = 0) -> Array:
util._check_arraylike("triu", m)
util.check_arraylike("triu", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
@ -2656,7 +2656,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
util._check_arraylike("trace", a)
util.check_arraylike("trace", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
dtypes.check_user_dtype_supported(dtype, "trace")
@ -2746,7 +2746,7 @@ def diag_indices(n, ndim=2):
@util._wraps(np.diag_indices_from)
def diag_indices_from(arr):
util._check_arraylike("diag_indices_from", arr)
util.check_arraylike("diag_indices_from", arr)
if not arr.ndim >= 2:
raise ValueError("input array must be at least 2-d")
@ -2758,7 +2758,7 @@ def diag_indices_from(arr):
@util._wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('offset', 'axis1', 'axis2'))
def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
util._check_arraylike("diagonal", a)
util.check_arraylike("diagonal", a)
a_shape = shape(a)
if ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
@ -2779,7 +2779,7 @@ def diag(v, k=0):
@partial(jit, static_argnames=('k',))
def _diag(v, k):
util._check_arraylike("diag", v)
util.check_arraylike("diag", v)
v_shape = shape(v)
if len(v_shape) == 1:
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
@ -2799,7 +2799,7 @@ return a scalar depending on the type of v.
@util._wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
util._check_arraylike("diagflat", v)
util.check_arraylike("diagflat", v)
v = ravel(v)
v_length = len(v)
adj_length = v_length + _abs(k)
@ -2848,7 +2848,7 @@ def append(arr, values, axis: Optional[int] = None):
@util._wraps(np.delete)
def delete(arr, obj, axis=None):
util._check_arraylike("delete", arr)
util.check_arraylike("delete", arr)
if axis is None:
arr = ravel(arr)
axis = 0
@ -2872,7 +2872,7 @@ def delete(arr, obj, axis=None):
# Case 3: obj is an array
# NB: pass both arrays to check for appropriate error message.
util._check_arraylike("delete", arr, obj)
util.check_arraylike("delete", arr, obj)
obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")
if issubdtype(obj.dtype, integer):
@ -2891,7 +2891,7 @@ def delete(arr, obj, axis=None):
@util._wraps(np.insert)
def insert(arr, obj, values, axis=None):
util._check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
arr = asarray(arr)
values = asarray(values)
@ -2972,8 +2972,8 @@ def apply_over_axes(func, a, axes):
@util._wraps(np.dot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
util._check_arraylike("dot", a, b)
a, b = util._promote_dtypes(a, b)
util.check_arraylike("dot", a, b)
a, b = util.promote_dtypes(a, b)
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
return lax.mul(a, b)
@ -2991,14 +2991,14 @@ def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
@util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
util._check_arraylike("matmul", a, b)
util.check_arraylike("matmul", a, b)
for i, x in enumerate((a, b)):
if ndim(x) < 1:
msg = (f"matmul input operand {i} must have ndim at least 1, "
f"but it has ndim {ndim(x)}")
raise ValueError(msg)
a, b = util._promote_dtypes(a, b)
a, b = util.promote_dtypes(a, b)
a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
a_batch_dims = shape(a)[:-2] if a_is_mat else ()
@ -3054,7 +3054,7 @@ def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
@util._wraps(np.vdot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def vdot(a, b, *, precision=None):
util._check_arraylike("vdot", a, b)
util.check_arraylike("vdot", a, b)
if issubdtype(_dtype(a), complexfloating):
a = ufuncs.conj(a)
return dot(ravel(a), ravel(b), precision=precision)
@ -3062,11 +3062,11 @@ def vdot(a, b, *, precision=None):
@util._wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
util._check_arraylike("tensordot", a, b)
util.check_arraylike("tensordot", a, b)
a_ndim = ndim(a)
b_ndim = ndim(b)
a, b = util._promote_dtypes(a, b)
a, b = util.promote_dtypes(a, b)
if type(axes) is int:
if axes > _min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
@ -3159,7 +3159,7 @@ def _removechars(s, chars):
def _einsum(operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
precision):
operands = list(util._promote_dtypes(*operands))
operands = list(util.promote_dtypes(*operands))
def sum(x, axes):
return lax.reduce(x, np.array(0, x.dtype),
lax.add if x.dtype != bool_ else lax.bitwise_or, axes)
@ -3304,7 +3304,7 @@ def inner(a, b, *, precision=None):
def outer(a, b, out=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
a, b = util._promote_dtypes(a, b)
a, b = util.promote_dtypes(a, b)
return ravel(a)[:, None] * ravel(b)[None, :]
@util._wraps(np.cross)
@ -3337,7 +3337,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
@util._wraps(np.kron)
@jit
def kron(a, b):
a, b = util._promote_dtypes(a, b)
a, b = util.promote_dtypes(a, b)
if ndim(a) < ndim(b):
a = expand_dims(a, range(ndim(b) - ndim(a)))
elif ndim(b) < ndim(a):
@ -3351,7 +3351,7 @@ def kron(a, b):
@util._wraps(np.vander)
@partial(jit, static_argnames=('N', 'increasing'))
def vander(x, N=None, increasing=False):
util._check_arraylike("vander", x)
util.check_arraylike("vander", x)
x = asarray(x)
if x.ndim != 1:
raise ValueError("x must be a one-dimensional array")
@ -3400,7 +3400,7 @@ def argwhere(a, *, size=None, fill_value=None):
@util._wraps(np.argmax, skip_params=['out'])
def argmax(a: ArrayLike, axis: Optional[int] = None, out=None, keepdims=None) -> Array:
util._check_arraylike("argmax", a)
util.check_arraylike("argmax", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
return _argmax(asarray(a), None if axis is None else operator.index(axis),
@ -3421,7 +3421,7 @@ def _argmax(a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Arr
@util._wraps(np.argmin, skip_params=['out'])
def argmin(a: ArrayLike, axis: Optional[int] = None, out=None, keepdims=None) -> Array:
util._check_arraylike("argmin", a)
util.check_arraylike("argmin", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
return _argmin(asarray(a), None if axis is None else operator.index(axis),
@ -3454,7 +3454,7 @@ def nanargmax(a, axis: Optional[int] = None, out : Any = None, keepdims : Option
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmax(a, axis: Optional[int] = None, keepdims: bool = False):
util._check_arraylike("nanargmax", a)
util.check_arraylike("nanargmax", a)
if not issubdtype(_dtype(a), inexact):
return argmax(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
@ -3470,7 +3470,7 @@ def nanargmin(a, axis: Optional[int] = None, out : Any = None, keepdims : Option
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmin(a, axis: Optional[int] = None, keepdims : bool = False):
util._check_arraylike("nanargmin", a)
util.check_arraylike("nanargmin", a)
if not issubdtype(_dtype(a), inexact):
return argmin(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
@ -3482,7 +3482,7 @@ def _nanargmin(a, axis: Optional[int] = None, keepdims : bool = False):
@util._wraps(np.sort)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
util._check_arraylike("sort", a)
util.check_arraylike("sort", a)
if kind != 'quicksort':
warnings.warn("'kind' argument to sort is ignored.")
if order is not None:
@ -3496,7 +3496,7 @@ def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
@util._wraps(np.sort_complex)
@jit
def sort_complex(a):
util._check_arraylike("sort_complex", a)
util.check_arraylike("sort_complex", a)
a = lax.sort(a, dimension=0)
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
@ -3524,7 +3524,7 @@ a warning and be treated as if they were :code:`'stable'`.
@util._wraps(np.argsort, lax_description=_ARGSORT_DOC)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
def argsort(a: ArrayLike, axis: Optional[int] = -1, kind: str = 'stable', order=None) -> Array:
util._check_arraylike("argsort", a)
util.check_arraylike("argsort", a)
arr = asarray(a)
if kind != 'stable':
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
@ -3561,7 +3561,7 @@ NaNs which have the negative bit set are sorted to the beginning of the array.
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
# TODO(jakevdp): handle NaN values like numpy.
util._check_arraylike("partition", a)
util.check_arraylike("partition", a)
arr = asarray(a)
if issubdtype(arr.dtype, np.complexfloating):
raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
@ -3587,7 +3587,7 @@ NaNs which have the negative bit set are sorted to the beginning of the array.
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
# TODO(jakevdp): handle NaN values like numpy.
util._check_arraylike("partition", a)
util.check_arraylike("partition", a)
arr = asarray(a)
if issubdtype(arr.dtype, np.complexfloating):
raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
@ -3633,7 +3633,7 @@ def _roll(a, shift, axis):
@util._wraps(np.roll)
def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None):
util._check_arraylike("roll", a,)
util.check_arraylike("roll", a,)
if isinstance(axis, list):
axis = tuple(axis)
return _roll(a, shift, axis)
@ -3642,7 +3642,7 @@ def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None):
@util._wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis', 'start'))
def rollaxis(a, axis: int, start=0):
util._check_arraylike("rollaxis", a)
util.check_arraylike("rollaxis", a)
start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()")
a_ndim = ndim(a)
axis = _canonicalize_axis(axis, a_ndim)
@ -3658,7 +3658,7 @@ def rollaxis(a, axis: int, start=0):
@util._wraps(np.packbits)
@partial(jit, static_argnames=('axis', 'bitorder'))
def packbits(a, axis: Optional[int] = None, bitorder='big'):
util._check_arraylike("packbits", a)
util.check_arraylike("packbits", a)
if not (issubdtype(_dtype(a), integer) or issubdtype(_dtype(a), bool_)):
raise TypeError('Expected an input array of integer or boolean data type')
if bitorder not in ['little', 'big']:
@ -3686,7 +3686,7 @@ def packbits(a, axis: Optional[int] = None, bitorder='big'):
@util._wraps(np.unpackbits)
@partial(jit, static_argnames=('axis', 'count', 'bitorder'))
def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
util._check_arraylike("unpackbits", a)
util.check_arraylike("unpackbits", a)
if _dtype(a) != uint8:
raise TypeError("Expected an input array of unsigned byte data type")
if bitorder not in ['little', 'big']:
@ -3735,7 +3735,7 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None,
unique_indices=False, indices_are_sorted=False, fill_value=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
util._check_arraylike("take", a, indices)
util.check_arraylike("take", a, indices)
a = asarray(a)
indices = asarray(indices)
@ -3813,7 +3813,7 @@ indexing in JAX.
@partial(jit, static_argnames=('axis', 'mode'))
def take_along_axis(arr, indices, axis: Optional[int],
mode: Optional[Union[str, lax.GatherScatterMode]] = None):
util._check_arraylike("take_along_axis", arr, indices)
util.check_arraylike("take_along_axis", arr, indices)
index_dtype = dtypes.dtype(indices)
if not dtypes.issubdtype(index_dtype, integer):
raise TypeError("take_along_axis indices must be of integer type, got "
@ -4483,8 +4483,8 @@ def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
@util._wraps(np.gcd, module='numpy')
@jit
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
util._check_arraylike("gcd", x1, x2)
x1, x2 = util._promote_dtypes(x1, x2)
util.check_arraylike("gcd", x1, x2)
x1, x2 = util.promote_dtypes(x1, x2)
if not issubdtype(_dtype(x1), integer):
raise ValueError("Arguments to jax.numpy.gcd must be integers.")
x1, x2 = broadcast_arrays(x1, x2)
@ -4495,8 +4495,8 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
@util._wraps(np.lcm, module='numpy')
@jit
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
util._check_arraylike("lcm", x1, x2)
x1, x2 = util._promote_dtypes(x1, x2)
util.check_arraylike("lcm", x1, x2)
x1, x2 = util.promote_dtypes(x1, x2)
x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2)
if not issubdtype(_dtype(x1), integer):
raise ValueError("Arguments to jax.numpy.lcm must be integers.")
@ -4513,7 +4513,7 @@ def extract(condition: ArrayLike, arr: ArrayLike) -> Array:
@util._wraps(np.compress, skip_params=['out'])
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None,
out: None = None) -> Array:
util._check_arraylike("compress", condition, a)
util.check_arraylike("compress", condition, a)
condition_arr = asarray(condition).astype(bool)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
@ -4538,11 +4538,11 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
fweights: Optional[ArrayLike] = None,
aweights: Optional[ArrayLike] = None) -> Array:
if y is not None:
m, y = util._promote_args_inexact("cov", m, y)
m, y = util.promote_args_inexact("cov", m, y)
if y.ndim > 2:
raise ValueError("y has more than 2 dimensions")
else:
m, = util._promote_args_inexact("cov", m)
m, = util.promote_args_inexact("cov", m)
if m.ndim > 2:
raise ValueError("m has more than 2 dimensions") # same as numpy error
@ -4563,7 +4563,7 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
w: Optional[Array] = None
if fweights is not None:
util._check_arraylike("cov", fweights)
util.check_arraylike("cov", fweights)
if ndim(fweights) > 1:
raise RuntimeError("cannot handle multidimensional fweights")
if shape(fweights)[0] != X.shape[1]:
@ -4573,7 +4573,7 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
# Ensure positive fweights; note that numpy raises an error on negative fweights.
w = asarray(ufuncs.abs(fweights))
if aweights is not None:
util._check_arraylike("cov", aweights)
util.check_arraylike("cov", aweights)
if ndim(aweights) > 1:
raise RuntimeError("cannot handle multidimensional aweights")
if shape(aweights)[0] != X.shape[1]:
@ -4602,7 +4602,7 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
@util._wraps(np.corrcoef)
@partial(jit, static_argnames=('rowvar',))
def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array:
util._check_arraylike("corrcoef", x)
util.check_arraylike("corrcoef", x)
c = cov(x, y, rowvar)
if len(shape(c)) == 0:
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
@ -4626,7 +4626,7 @@ def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util._check_arraylike("quantile", a, q)
util.check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
@ -4642,7 +4642,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ..
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util._check_arraylike("nanquantile", a, q)
util.check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
@ -4657,7 +4657,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a, = util._promote_dtypes_inexact(a)
a, = util.promote_dtypes_inexact(a)
keepdim = []
if issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
@ -4818,7 +4818,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
@partial(jit, static_argnames=('side', 'sorter', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: None = None, *, method: str = 'scan') -> Array:
util._check_arraylike("searchsorted", a, v)
util.check_arraylike("searchsorted", a, v)
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
"Expected one of ['left', 'right'].")
@ -4829,7 +4829,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
raise NotImplementedError("sorter is not implemented")
if ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
a, v = util._promote_dtypes(a, v)
a, v = util.promote_dtypes(a, v)
dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
if len(a) == 0:
return zeros_like(v, dtype=dtype)
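
# Illustration (sketch): searchsorted promotes a and v to a common dtype and
# returns int32 indices unless a is large enough to require int64.
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 5])
jnp.searchsorted(a, 4)    # 3, the insertion point that keeps a sorted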
@ -4843,7 +4843,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
@util._wraps(np.digitize)
@partial(jit, static_argnames=('right',))
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
util._check_arraylike("digitize", x, bins)
util.check_arraylike("digitize", x, bins)
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
bins_arr = asarray(bins)
if bins_arr.ndim != 1:
@ -4867,7 +4867,7 @@ See the :func:`jax.lax.switch` documentation for more information.
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
funclist: List[Union[ArrayLike, Callable[..., Array]]],
*args, **kw) -> Array:
util._check_arraylike("piecewise", x)
util.check_arraylike("piecewise", x)
nc, nf = len(condlist), len(funclist)
if nf == nc + 1:
funclist = funclist[-1:] + funclist[:-1]
@ -4904,8 +4904,8 @@ def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util._check_arraylike("percentile", a, q)
q, = util._promote_dtypes_inexact(q)
util.check_arraylike("percentile", a, q)
q, = util.promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
@ -4916,7 +4916,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util._check_arraylike("nanpercentile", a, q)
util.check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
@ -4927,7 +4927,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util._check_arraylike("median", a)
util.check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@ -4936,7 +4936,7 @@ def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util._check_arraylike("nanmedian", a)
util.check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')
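
# Illustration (sketch): median is quantile(a, 0.5) with the 'midpoint' method,
# and percentile is quantile with q rescaled from [0, 100] to [0, 1].
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 4.0])
jnp.median(a)           # 2.5
jnp.percentile(a, 50)   # 2.5
jnp.quantile(a, 0.5)    # 2.5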
@ -5012,7 +5012,7 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
if type is not None:
raise NotImplementedError("`type` argument of array.view() is not supported.")
util._check_arraylike("view", arr)
util.check_arraylike("view", arr)
arr = asarray(arr)
dtypes.check_user_dtype_supported(dtype, "view")

View File

@ -28,7 +28,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike
from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array
@ -47,8 +47,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@_wraps(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.cholesky", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.cholesky", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.cholesky(a)
@overload
@ -71,8 +71,8 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
_check_arraylike("jnp.linalg.svd", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
s = lax.abs(v)
@ -95,8 +95,8 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
_check_arraylike("jnp.linalg.matrix_power", a)
arr, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.matrix_power", a)
arr, = promote_dtypes_inexact(jnp.asarray(a))
if arr.ndim < 2:
raise TypeError("{}-dimensional array given. Array must be at least "
@ -134,8 +134,8 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
_check_arraylike("jnp.linalg.matrix_rank", M)
M, = _promote_dtypes_inexact(jnp.asarray(M))
check_arraylike("jnp.linalg.matrix_rank", M)
M, = promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return (M != 0).any().astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
@ -197,8 +197,8 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.slogdet", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
@ -269,8 +269,8 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
Returns:
det(a) and cofactor(a)^T*b, aka adjugate(a)*b
"""
a, = _promote_dtypes_inexact(jnp.asarray(a))
b, = _promote_dtypes_inexact(jnp.asarray(b))
a, = promote_dtypes_inexact(jnp.asarray(a))
b, = promote_dtypes_inexact(jnp.asarray(b))
a_shape = jnp.shape(a)
b_shape = jnp.shape(b)
a_ndims = len(a_shape)
@ -336,8 +336,8 @@ def _det_3x3(a: Array) -> Array:
@_wraps(np.linalg.det)
@jit
def det(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.det", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.det", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
return _det_2x2(a)
@ -369,8 +369,8 @@ backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.eig", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.eig", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
return w, v
@ -378,7 +378,7 @@ def eig(a: ArrayLike) -> Tuple[Array, Array]:
@_wraps(np.linalg.eigvals)
@jit
def eigvals(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.eigvals", a)
check_arraylike("jnp.linalg.eigvals", a)
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]
@ -387,7 +387,7 @@ def eigvals(a: ArrayLike) -> Array:
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.eigh", a)
check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
@ -396,7 +396,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None,
msg = f"UPLO must be one of None, 'L', or 'U', got {UPLO}"
raise ValueError(msg)
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v
@ -404,7 +404,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None,
@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
_check_arraylike("jnp.linalg.eigvalsh", a)
check_arraylike("jnp.linalg.eigvalsh", a)
w, _ = eigh(a, UPLO)
return w
@ -420,7 +420,7 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
hermitian: bool = False) -> Array:
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
_check_arraylike("jnp.linalg.pinv", a)
check_arraylike("jnp.linalg.pinv", a)
arr = jnp.asarray(a)
m, n = arr.shape[-2:]
if m == 0 or n == 0:
@ -473,7 +473,7 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
@_wraps(np.linalg.inv)
@jit
def inv(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.inv", a)
check_arraylike("jnp.linalg.inv", a)
arr = jnp.asarray(a)
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
raise ValueError(
@ -487,8 +487,8 @@ def inv(a: ArrayLike) -> Array:
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
_check_arraylike("jnp.linalg.norm", x)
x, = _promote_dtypes_inexact(jnp.asarray(x))
check_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
@ -587,8 +587,8 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
_check_arraylike("jnp.linalg.qr", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
check_arraylike("jnp.linalg.qr", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
return _T(a), taus
@ -607,8 +607,8 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
@_wraps(np.linalg.solve)
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.solve", a, b)
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
return lax_linalg._solve(a, b)
@ -616,7 +616,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_dtypes_inexact(a, b)
a, b = promote_dtypes_inexact(a, b)
if a.shape[0] != b.shape[0]:
raise ValueError("Leading dimensions of input arrays must match")
b_orig_ndim = b.ndim
@ -674,7 +674,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
"""))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
_check_arraylike("jnp.linalg.lstsq", a, b)
check_arraylike("jnp.linalg.lstsq", a, b)
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
return _jit_lstsq(a, b, rcond)
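
The pattern repeated throughout this file is the same: `check_arraylike` rejects non-array inputs up front, and `promote_dtypes_inexact` turns integer or boolean inputs into a floating dtype before the `lax_linalg` kernel runs. A minimal sketch of the observable effect through the public API (the exact dtypes assume the default 32-bit mode; with `jax_enable_x64` they become float64):

import jax.numpy as jnp

a = jnp.array([[4, 0],
               [0, 9]])                   # int32 input
L = jnp.linalg.cholesky(a)                # promoted to float before the kernel
print(L.dtype)                            # float32 under the default config
print(L)                                  # [[2. 0.] [0. 3.]]

try:
    jnp.linalg.cholesky("not an array")   # rejected by check_arraylike
except TypeError as err:
    print(err)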

View File

@ -31,7 +31,7 @@ from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
_check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps)
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, _wraps)
from jax._src.typing import Array, ArrayLike
@ -84,8 +84,8 @@ strip_zeros : bool, default=True
:func:`jax.jit` and other JAX transformations.
""")
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
_check_arraylike("roots", p)
p_arr = atleast_1d(*_promote_dtypes_inexact(p))
check_arraylike("roots", p)
p_arr = atleast_1d(*promote_dtypes_inexact(p))
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
@ -111,7 +111,7 @@ Also, it works best on rcond <= 10e-3 values.
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
full: bool = False, w: Optional[Array] = None, cov: bool = False
) -> Union[Array, Tuple[Array, ...]]:
_check_arraylike("polyfit", x, y)
check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
@ -136,8 +136,8 @@ def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
# apply weighting
if w is not None:
_check_arraylike("polyfit", w)
w, = _promote_dtypes_inexact(w)
check_arraylike("polyfit", w)
w, = promote_dtypes_inexact(w)
if w.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w.shape[0] != y.shape[0]:
@ -190,8 +190,8 @@ jax returns an array with a complex dtype in such cases.
@_wraps(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros: Array) -> Array:
_check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
sh = seq_of_zeros.shape
@ -224,8 +224,8 @@ compilation time.
""")
@partial(jit, static_argnames=['unroll'])
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
_check_arraylike("polyval", p, x)
p, x = _promote_dtypes_inexact(p, x)
check_arraylike("polyval", p, x)
p, x = promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
@ -234,8 +234,8 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
@_wraps(np.polyadd)
@jit
def polyadd(a1: Array, a2: Array) -> Array:
_check_arraylike("polyadd", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
check_arraylike("polyadd", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
if a2.shape[0] <= a1.shape[0]:
return a1.at[-a2.shape[0]:].add(a2)
else:
@ -247,8 +247,8 @@ def polyadd(a1: Array, a2: Array) -> Array:
def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array:
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
_check_arraylike("polyint", p, k)
p, k_arr = _promote_dtypes_inexact(p, k)
check_arraylike("polyint", p, k)
p, k_arr = promote_dtypes_inexact(p, k)
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k_arr = atleast_1d(k_arr)
@ -268,9 +268,9 @@ def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array:
@_wraps(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p: Array, m: int = 1) -> Array:
_check_arraylike("polyder", p)
check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = _promote_dtypes_inexact(p)
p, = promote_dtypes_inexact(p)
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
@ -290,8 +290,8 @@ JAX backends. The result may lead to inconsistent output shapes when trim_leadin
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
_check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = _promote_dtypes_inexact(a1, a2)
check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
if len(a1_arr) == 0:
@ -302,8 +302,8 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]:
_check_arraylike("polydiv", u, v)
u_arr, v_arr = _promote_dtypes_inexact(u, v)
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
m = len(u_arr) - 1
n = len(v_arr) - 1
scale = 1. / v_arr[0]
@ -320,6 +320,6 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
@_wraps(np.polysub)
@jit
def polysub(a1: Array, a2: Array) -> Array:
_check_arraylike("polysub", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
check_arraylike("polysub", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
return polyadd(a1, -a2)
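
For reference, `polyval` above is Horner's rule expressed as a `lax.scan` over the promoted coefficients; a small sketch of the equivalence (nothing beyond default dtype promotion is assumed):

import jax.numpy as jnp

p = jnp.array([2, -3, 5])          # 2*x**2 - 3*x + 5, integer coefficients
x = jnp.array([0.0, 1.0, 2.0])

# p and x are promoted to a common inexact dtype, then evaluated by
# Horner's rule: y = (2*x - 3)*x + 5.
print(jnp.polyval(p, x))           # [5. 4. 7.]

# The same recurrence written out explicitly.
y = jnp.zeros_like(x)
for c in p:
    y = y * x + c
print(y)                           # [5. 4. 7.]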

View File

@ -26,8 +26,8 @@ from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy.util import (
_broadcast_to, _check_arraylike, _complex_elem_type,
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
@ -52,7 +52,7 @@ def _isscalar(element: Any) -> bool:
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
# simplified version of jnp.moveaxis() for local use.
_check_arraylike("moveaxis", a)
check_arraylike("moveaxis", a)
a = _asarray(a)
source = _canonicalize_axis(source, np.ndim(a))
destination = _canonicalize_axis(destination, np.ndim(a))
@ -83,7 +83,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
_check_arraylike(name, a)
check_arraylike(name, a)
dtypes.check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
@ -180,7 +180,7 @@ def _cast_to_bool(operand: ArrayLike) -> Array:
return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array:
return _promote_dtypes_numeric(operand)[0]
return promote_dtypes_numeric(operand)[0]
def _ensure_optional_axes(x: Axis) -> Axis:
@ -326,7 +326,7 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("mean", a)
check_arraylike("mean", a)
dtypes.check_user_dtype_supported(dtype, "mean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
@ -365,8 +365,8 @@ def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
if weights is None: # Treat all weights as 1
_check_arraylike("average", a)
a, = _promote_dtypes_inexact(a)
check_arraylike("average", a)
a, = promote_dtypes_inexact(a)
avg = mean(a, axis=axis, keepdims=keepdims)
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
@ -375,8 +375,8 @@ def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = Non
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:
_check_arraylike("average", a, weights)
a, weights = _promote_dtypes_inexact(a, weights)
check_arraylike("average", a, weights)
a, weights = promote_dtypes_inexact(a, weights)
a_shape = np.shape(a)
a_ndim = len(a_shape)
@ -427,7 +427,7 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("var", a)
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
@ -487,7 +487,7 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("std", a)
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
@ -502,7 +502,7 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
_check_arraylike("ptp", a)
check_arraylike("ptp", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
x = amax(a, axis=axis, keepdims=keepdims)
@ -514,7 +514,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
_check_arraylike("count_nonzero", a)
check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
@ -522,7 +522,7 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
_check_arraylike(name, a)
check_arraylike(name, a)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
@ -580,7 +580,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanmean", a)
check_arraylike("nanmean", a)
dtypes.check_user_dtype_supported(dtype, "nanmean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
@ -600,7 +600,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanvar", a)
check_arraylike("nanvar", a)
dtypes.check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
@ -631,7 +631,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None =
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanstd", a)
check_arraylike("nanstd", a)
dtypes.check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
@ -649,7 +649,7 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
@partial(api.jit, static_argnames=('axis', 'dtype'))
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike = None, out: None = None) -> Array:
_check_arraylike(np_reduction.__name__, a)
check_arraylike(np_reduction.__name__, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
f"is not supported.")

View File

@ -32,7 +32,7 @@ from jax._src.numpy.lax_numpy import (
sort, where, zeros)
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike
@ -48,7 +48,7 @@ def in1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, invert: bo
@partial(jit, static_argnames=('invert',))
def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
_check_arraylike("in1d", ar1, ar2)
check_arraylike("in1d", ar1, ar2)
ar1_flat = ravel(ar1)
ar2_flat = ravel(ar2)
# Note: an algorithm based on searchsorted has better scaling, but in practice
@ -80,7 +80,7 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
_check_arraylike("setdiff1d", ar1, ar2)
check_arraylike("setdiff1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
else:
@ -118,7 +118,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
value of the union."""))
def union1d(ar1: ArrayLike, ar2: ArrayLike,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
_check_arraylike("union1d", ar1, ar2)
check_arraylike("union1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
@ -132,7 +132,7 @@ In the JAX version, the input arrays are explicitly flattened regardless
of assume_unique value.
""")
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
_check_arraylike("setxor1d", ar1, ar2)
check_arraylike("setxor1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
@ -174,7 +174,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo
@_wraps(np.intersect1d)
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
_check_arraylike("intersect1d", ar1, ar2)
check_arraylike("intersect1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
@ -326,7 +326,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, axis: Optional[int] = None,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None):
_check_arraylike("unique", ar)
check_arraylike("unique", ar)
if size is None:
ar = core.concrete_or_error(None, ar,
"The error arose for the first argument of jnp.unique(). " + UNIQUE_SIZE_HINT)

View File

@ -29,9 +29,9 @@ from jax._src.api import jit, custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
_asarray, _check_arraylike, _promote_args, _promote_args_inexact,
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
_promote_shapes, _where, _wraps)
_asarray, check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
_lax_const = lax._const
@ -57,9 +57,9 @@ def _one_to_one_unop(
numpy_fn: Callable[..., Any], lax_fn: UnOp,
promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
if promote_to_inexact:
fn = lambda x, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
else:
fn = lambda x, /: lax_fn(*_promote_args(numpy_fn.__name__, x))
fn = lambda x, /: lax_fn(*promote_args(numpy_fn.__name__, x))
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
if lax_doc:
@ -74,11 +74,11 @@ def _one_to_one_binop(
promote_to_inexact: bool = False, lax_doc: bool = False,
promote_to_numeric: bool = False) -> BinOp:
if promote_to_inexact:
fn = lambda x1, x2, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x1, x2))
elif promote_to_numeric:
fn = lambda x1, x2, /: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*promote_args_numeric(numpy_fn.__name__, x1, x2))
else:
fn = lambda x1, x2, /: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
if lax_doc:
@ -92,7 +92,7 @@ def _maybe_bool_binop(
numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
lax_doc: bool = False) -> BinOp:
def fn(x1, x2, /):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
@ -105,7 +105,7 @@ def _maybe_bool_binop(
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
def fn(x1, x2, /):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
# Comparison on complex types are defined as a lexicographic ordering on
# the (real, imag) pair.
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
@ -132,7 +132,7 @@ def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Un
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x))
for x in args)
return bitwise_op(*_promote_args(np_op.__name__, *args))
return bitwise_op(*promote_args(np_op.__name__, *args))
return op
@ -191,7 +191,7 @@ logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)
def arccosh(x: ArrayLike, /) -> Array:
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
out = lax.acosh(*_promote_args_inexact("arccosh", x))
out = lax.acosh(*promote_args_inexact("arccosh", x))
if dtypes.issubdtype(out.dtype, np.complexfloating):
out = _where(real(out) < 0, lax.neg(out), out)
return out
@ -200,7 +200,7 @@ def arccosh(x: ArrayLike, /) -> Array:
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2)
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)
@ -209,7 +209,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
_check_arraylike('absolute', x)
check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs, module='numpy')(absolute)
@ -218,7 +218,7 @@ abs = _wraps(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@jit
def rint(x: ArrayLike, /) -> Array:
_check_arraylike('rint', x)
check_arraylike('rint', x)
dtype = dtypes.dtype(x)
if dtype == bool or dtypes.issubdtype(dtype, np.integer):
return lax.convert_element_type(x, dtypes.float_)
@ -230,7 +230,7 @@ def rint(x: ArrayLike, /) -> Array:
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
_check_arraylike('sign', x)
check_arraylike('sign', x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.complexfloating):
re = lax.real(x)
@ -242,7 +242,7 @@ def sign(x: ArrayLike, /) -> Array:
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("copysign", x1, x2)
x1, x2 = promote_args_inexact("copysign", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
raise TypeError("copysign does not support complex-valued inputs")
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
@ -251,7 +251,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
x1, x2 = promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)
divide = true_divide
@ -260,7 +260,7 @@ divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric("floor_divide", x1, x2)
x1, x2 = promote_args_numeric("floor_divide", x1, x2)
dtype = dtypes.dtype(x1)
if dtypes.issubdtype(dtype, np.integer):
quotient = lax.div(x1, x2)
@ -285,7 +285,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
x1, x2 = _promote_args_numeric("divmod", x1, x2)
x1, x2 = promote_args_numeric("divmod", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
return floor_divide(x1, x2), remainder(x1, x2)
else:
@ -306,7 +306,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
x1, x2 = _promote_args_numeric("power", x1, x2)
x1, x2 = promote_args_numeric("power", x1, x2)
dtype = dtypes.dtype(x1)
if not dtypes.issubdtype(dtype, np.integer):
return lax.pow(x1, x2)
@ -328,7 +328,7 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("power", x1, x2)
check_arraylike("power", x1, x2)
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
@ -339,7 +339,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
except TypeError:
pass
else:
x1, = _promote_dtypes_numeric(x1)
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)
return _power(x1, x2)
@ -348,7 +348,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
x1, x2 = promote_args_inexact("logaddexp", x1, x2)
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
@ -375,7 +375,7 @@ def _wrap_between(x, _a):
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
primal_out = logaddexp(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
@ -386,7 +386,7 @@ def _logaddexp_jvp(primals, tangents):
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
@ -404,7 +404,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
primal_out = logaddexp2(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
@ -414,28 +414,28 @@ def _logaddexp2_jvp(primals, tangents):
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("log2", x)
x, = promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("log10", x)
x, = promote_args_inexact("log10", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("exp2", x)
x, = promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike, /) -> Array:
x, = _promote_args("signbit", x)
x, = promote_args("signbit", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.integer):
return lax.lt(x, _constant_like(x, 0))
@ -472,14 +472,14 @@ def _normalize_float(x):
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("ldexp", x1, x2)
check_arraylike("ldexp", x1, x2)
x1_dtype = dtypes.dtype(x1)
x2_dtype = dtypes.dtype(x2)
if (dtypes.issubdtype(x1_dtype, np.complexfloating)
or dtypes.issubdtype(x2_dtype, np.inexact)):
raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")
x1, x2 = _promote_shapes("ldexp", x1, x2)
x1, x2 = promote_shapes("ldexp", x1, x2)
dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
info = dtypes.finfo(dtype)
@ -521,8 +521,8 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
_check_arraylike("frexp", x)
x, = _promote_dtypes_inexact(x)
check_arraylike("frexp", x)
x, = promote_dtypes_inexact(x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
raise TypeError("frexp does not support complex-valued inputs")
@ -545,7 +545,7 @@ def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric("remainder", x1, x2)
x1, x2 = promote_args_numeric("remainder", x1, x2)
zero = _constant_like(x1, 0)
if dtypes.issubdtype(x2.dtype, np.integer):
x2 = _where(x2 == 0, lax._ones(x2), x2)
@ -560,31 +560,31 @@ mod = _wraps(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("fmod", x1, x2)
check_arraylike("fmod", x1, x2)
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
x2 = _where(x2 == 0, lax._ones(x2), x2)
return lax.rem(*_promote_args_numeric("fmod", x1, x2))
return lax.rem(*promote_args_numeric("fmod", x1, x2))
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
_check_arraylike("square", x)
x, = _promote_dtypes_numeric(x)
check_arraylike("square", x)
x, = promote_dtypes_numeric(x)
return lax.integer_pow(x, 2)
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("deg2rad", x)
x, = promote_args_inexact("deg2rad", x)
return lax.mul(x, _lax_const(x, np.pi / 180))
@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("rad2deg", x)
x, = promote_args_inexact("rad2deg", x)
return lax.mul(x, _lax_const(x, 180 / np.pi))
@ -595,7 +595,7 @@ radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
_check_arraylike("conjugate", x)
check_arraylike("conjugate", x)
return lax.conj(x) if np.iscomplexobj(x) else _asarray(x)
conj = conjugate
@ -603,21 +603,21 @@ conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
_check_arraylike("imag", val)
check_arraylike("imag", val)
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
_check_arraylike("real", val)
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else _asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
_check_arraylike("modf", x)
x, = _promote_dtypes_inexact(x)
check_arraylike("modf", x)
x, = promote_dtypes_inexact(x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
@ -627,7 +627,7 @@ def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike, /) -> Array:
_check_arraylike("isfinite", x)
check_arraylike("isfinite", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.floating):
return lax.is_finite(x)
@ -640,7 +640,7 @@ def isfinite(x: ArrayLike, /) -> Array:
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
_check_arraylike("isinf", x)
check_arraylike("isinf", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.floating):
return lax.eq(lax.abs(x), _constant_like(x, np.inf))
@ -678,15 +678,15 @@ isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike, /) -> Array:
_check_arraylike("isnan", x)
check_arraylike("isnan", x)
return lax.ne(x, x)
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("heaviside", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
check_arraylike("heaviside", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
zero = _lax_const(x1, 0)
return _where(lax.lt(x1, zero), zero,
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
@ -695,8 +695,8 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("hypot", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
check_arraylike("hypot", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1)
x2 = lax.abs(x2)
x1, x2 = maximum(x1, x2), minimum(x1, x2)
@ -706,16 +706,16 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
_check_arraylike("reciprocal", x)
x, = _promote_dtypes_inexact(x)
check_arraylike("reciprocal", x)
x, = promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike, /) -> Array:
_check_arraylike("sinc", x)
x, = _promote_dtypes_inexact(x)
check_arraylike("sinc", x)
x, = promote_dtypes_inexact(x)
eq_zero = lax.eq(x, _lax_const(x, 0))
pi_x = lax.mul(_lax_const(x, np.pi), x)
safe_pi_x = _where(eq_zero, _lax_const(x, 1), pi_x)
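
Most of the ufuncs above only change which promotion helper they call; the numerics are unchanged. As a reminder of why `logaddexp` is written with the max-shift seen in this hunk, a small sketch (float32 default assumed):

import jax.numpy as jnp

# Mixed int / float inputs are promoted to a common inexact dtype.
print(jnp.logaddexp(1, 2.0).dtype)          # float32 by default

# The max-shift keeps the result finite where the naive formula overflows.
a, b = 1000.0, 1001.0
print(jnp.logaddexp(a, b))                  # ~1001.3133
print(jnp.log(jnp.exp(a) + jnp.exp(b)))     # inf: exp overflows first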

View File

@ -228,11 +228,11 @@ def _asarray(arr: ArrayLike) -> Array:
Pared-down utility to convert object to a DeviceArray.
Note this will not correctly handle lists or tuples.
"""
_check_arraylike("_asarray", arr)
check_arraylike("_asarray", arr)
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax._convert_element_type(arr, dtype, weak_type)
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [_asarray(arg) for arg in args]
@ -273,7 +273,7 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
def promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
@ -284,7 +284,7 @@ def _promote_dtypes(*args: ArrayLike) -> List[Array]:
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
@ -295,7 +295,7 @@ def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
for x in args]
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a numeric (non-bool) type."""
@ -306,7 +306,7 @@ def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
for x in args]
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
def promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a complex type."""
@ -333,7 +333,7 @@ stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add
def _check_arraylike(fun_name: str, *args: Any):
def check_arraylike(fun_name: str, *args: Any):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if any(not _arraylike(arg) for arg in args):
@ -356,26 +356,26 @@ def _check_no_float0s(fun_name: str, *args: Any):
"taken a gradient with respect to an integer argument.")
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))
return promote_shapes(fun_name, *promote_dtypes(*args))
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
_check_arraylike(fun_name, *args)
def promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))
return promote_shapes(fun_name, *promote_dtypes_numeric(*args))
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
_check_arraylike(fun_name, *args)
check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
return promote_shapes(fun_name, *promote_dtypes_inexact(*args))
@partial(api.jit, inline=True)
@ -391,7 +391,7 @@ def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape) # type: ignore[union-attr]
_check_arraylike("broadcast_to", arr)
check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, Array) else _asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
shape = (shape,)
@ -425,7 +425,7 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
.format(x, y))
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax._zero(condition))
x, y = _promote_dtypes(x, y)
x, y = promote_dtypes(x, y)
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:
is_always_empty = core.is_empty_shape(x_arr.shape)
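
This file is the source of the rename: the shared promotion helpers drop their leading underscore but keep their semantics, NumPy-style shape broadcasting plus dtype promotion (`promote_args` is `promote_shapes` applied to `promote_dtypes`). A sketch of the behaviour they implement, shown only through public APIs since the helpers themselves live under `jax._src` and are not a supported import path:

import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int32)          # shape (3,)
y = jnp.ones((2, 1), dtype=jnp.float32)     # shape (2, 1)

# The result has the broadcast shape and the NumPy-style promoted dtype
# of its inputs.
out = jnp.add(x, y)
print(out.shape)                            # (2, 3)
print(out.dtype)                            # float32
print(jnp.result_type(x, y))                # float32, same promotion rule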

View File

@ -28,7 +28,7 @@ from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy.util import _check_arraylike, _promote_dtypes
from jax._src.numpy.util import check_arraylike, promote_dtypes
Array = Any
@ -100,7 +100,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
if core.is_empty_shape(indexer.slice_shape):
return x
x, y = _promote_dtypes(x, y)
x, y = promote_dtypes(x, y)
# Broadcast `y` to the slice output shape.
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
@ -157,7 +157,7 @@ def _segment_update(name: str,
bucket_size: Optional[int] = None,
reducer: Optional[Callable] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
_check_arraylike(name, data, segment_ids)
check_arraylike(name, data, segment_ids)
mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
data = jnp.asarray(data)
segment_ids = jnp.asarray(segment_ids)
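
`_segment_update` above backs the public segment reductions: each data element is scattered into the bucket named by its segment id. A small sketch with `segment_sum`:

import jax.numpy as jnp
from jax.ops import segment_sum

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 2])

# Values with the same id are summed into one output slot.
print(segment_sum(data, segment_ids, num_segments=3))   # [3. 3. 4.]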

View File

@ -18,7 +18,7 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.reductions import _reduction_dims
from jax._src.numpy.util import _promote_args_inexact
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
import numpy as np
@ -67,10 +67,10 @@ def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] =
on the value of the ``return_sign`` argument.
"""
if b is not None:
a_arr, b_arr = _promote_args_inexact("logsumexp", a, b)
a_arr, b_arr = promote_args_inexact("logsumexp", a, b)
a_arr = jnp.where(b_arr != 0, a_arr, -jnp.inf)
else:
a_arr, = _promote_args_inexact("logsumexp", a)
a_arr, = promote_args_inexact("logsumexp", a)
b_arr = a_arr # for type checking
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
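
The weighting logic above computes `log(sum(b * exp(a)))`: entries with `b == 0` are pushed to `-inf` so they drop out of the reduction, and the max-shift keeps large inputs finite. A sketch through the public wrapper:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.0, 1000.0])
print(logsumexp(a))              # 1000 + log(3), finite despite exp overflow

b = jnp.array([1.0, 1.0, 0.0])   # a zero weight removes the last term
print(logsumexp(a, b=b))         # 1000 + log(2)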

View File

@ -38,7 +38,7 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis
@ -330,7 +330,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
if not jnp.issubdtype(dtype, np.integer):
raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
_check_arraylike("randint", minval, maxval)
check_arraylike("randint", minval, maxval)
minval = jnp.asarray(minval)
maxval = jnp.asarray(maxval)
if not jnp.issubdtype(minval.dtype, np.integer):
@ -423,7 +423,7 @@ def permutation(key: KeyArray,
A shuffled version of x or array range
"""
key, _ = _check_prng_key(key)
_check_arraylike("permutation", x)
check_arraylike("permutation", x)
axis = canonicalize_axis(axis, np.ndim(x) or 1)
if not np.ndim(x):
if not np.issubdtype(lax.dtype(x), np.integer):
@ -500,7 +500,7 @@ def choice(key: KeyArray,
if not isinstance(shape, Sequence):
raise TypeError("shape argument of jax.random.choice must be a sequence, "
f"got {shape}")
_check_arraylike("choice", a)
check_arraylike("choice", a)
arr = jnp.asarray(a)
if arr.ndim == 0:
n_inputs = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
@ -523,8 +523,8 @@ def choice(key: KeyArray,
slices = (slice(None),) * axis + (slice(n_draws),)
result = permutation(key, n_inputs if arr.ndim == 0 else arr, axis)[slices]
else:
_check_arraylike("choice", p)
p_arr, = _promote_dtypes_inexact(p)
check_arraylike("choice", p)
p_arr, = promote_dtypes_inexact(p)
if p_arr.shape != (n_inputs,):
raise ValueError("p must be None or match the shape of a")
if replace:
@ -615,7 +615,7 @@ def multivariate_normal(key: KeyArray,
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
"""
key, _ = _check_prng_key(key)
mean, cov = _promote_dtypes_inexact(mean, cov)
mean, cov = promote_dtypes_inexact(mean, cov)
if method not in {'svd', 'eigh', 'cholesky'}:
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
if dtype is None:
@ -1328,7 +1328,7 @@ def categorical(key: KeyArray,
is not None, or else ``np.delete(logits.shape, axis)``.
"""
key, _ = _check_prng_key(key)
_check_arraylike("categorical", logits)
check_arraylike("categorical", logits)
logits_arr = jnp.asarray(logits)
if axis >= 0:
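
`permutation` above distinguishes scalar and array inputs after `check_arraylike`; a sketch of the two code paths (outputs depend on the key, so the values shown are only indicative):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# A scalar n is treated as arange(n) to shuffle.
print(jax.random.permutation(key, 5))          # e.g. [3 0 2 4 1]

# An array is shuffled along `axis` (rows here, axis=0 by default).
x = jnp.arange(6).reshape(3, 2)
print(jax.random.permutation(key, x))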

View File

@ -19,7 +19,7 @@ import textwrap
from jax import vmap
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps, check_arraylike, promote_dtypes_inexact
_no_chkfinite_doc = textwrap.dedent("""
@ -30,10 +30,10 @@ because compiled JAX code cannot perform checks of array values at runtime
@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
def vq(obs, code_book, check_finite=True):
_check_arraylike("scipy.cluster.vq.vq", obs, code_book)
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
if obs.ndim != code_book.ndim:
raise ValueError("Observation and code_book should have the same rank")
obs, code_book = _promote_dtypes_inexact(obs, code_book)
obs, code_book = promote_dtypes_inexact(obs, code_book)
if obs.ndim == 1:
obs, code_book = obs[..., None], code_book[..., None]
if obs.ndim != 2:
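
A short usage sketch of `vq` for orientation, assuming the public export `jax.scipy.cluster.vq`: each observation row is assigned to the nearest code-book row, and the second return value is that distance, mirroring `scipy.cluster.vq.vq`:

import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

obs = jnp.array([[0.0, 0.0],
                 [1.1, 1.0],
                 [0.1, 0.2]])
code_book = jnp.array([[0.0, 0.0],
                       [1.0, 1.0]])

codes, dists = vq(obs, code_book)
print(codes)      # [0 1 0]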

View File

@ -19,11 +19,11 @@ import scipy.fft as osp_fft
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _wraps, _promote_dtypes_complex
from jax._src.numpy.util import _wraps, promote_dtypes_complex
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
N_arr, k = _promote_dtypes_complex(N, k)
N_arr, k = promote_dtypes_complex(N, k)
return jnp.exp(-.5j * jnp.pi * k / N_arr)
def _dct_interleave(x: Array, axis: int) -> Array:

View File

@ -29,8 +29,8 @@ from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.util import (
_check_arraylike, _wraps, _promote_dtypes, _promote_dtypes_inexact,
_promote_dtypes_complex)
check_arraylike, _wraps, promote_dtypes, promote_dtypes_inexact,
promote_dtypes_complex)
from jax._src.typing import Array, ArrayLike
@ -43,7 +43,7 @@ _no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + "\nDoes not support the Sc
@partial(jit, static_argnames=('lower',))
def _cholesky(a: ArrayLike, lower: bool) -> Array:
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
l = lax_linalg.cholesky(a if lower else jnp.conj(_T(a)), symmetrize_input=False)
return l if lower else jnp.conj(_T(l))
@ -63,7 +63,7 @@ def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
@partial(jit, static_argnames=('lower',))
def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:
c, b = _promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b))
c, b = promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b))
lax_linalg._check_solve_shapes(c, b)
b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
transpose_a=not lower, conjugate_a=not lower)
@ -90,7 +90,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array,
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, Tuple[Array, Array, Array]]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@overload
@ -151,7 +151,7 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
raise NotImplementedError(
"Only the eigvals=None case of eigh is implemented.")
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower)
if eigvals_only:
@ -218,7 +218,7 @@ def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Tuple[Array, Array]:
del overwrite_a, check_finite # unused
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, pivots, _ = lax_linalg.lu(a)
return lu, pivots
@ -245,7 +245,7 @@ def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array
@partial(jit, static_argnums=(1,))
def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, _, permutation = lax_linalg.lu(a)
dtype = lax.dtype(a)
m, n = jnp.shape(a)
@ -298,7 +298,7 @@ def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Ar
full_matrices = False
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
a, = _promote_dtypes_inexact(jnp.asarray(a))
a, = promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return (r,)
@ -334,7 +334,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
if assume_a != 'pos':
return jnp.linalg.solve(a, b)
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
lax_linalg._check_solve_shapes(a, b)
# With custom_linear_solve, we can reuse the same factorization when
@ -382,7 +382,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: Union[int, str],
else:
raise ValueError(f"Invalid 'trans' value {trans}")
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
# lax_linalg.triangular_solve only supports matrix 'b's at the moment.
b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1
@ -609,7 +609,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
def block_diag(*arrs: ArrayLike) -> Array:
if len(arrs) == 0:
arrs = cast(Tuple[ArrayLike], (jnp.zeros((1, 0)),))
arrs = cast(Tuple[ArrayLike], _promote_dtypes(*arrs))
arrs = cast(Tuple[ArrayLike], promote_dtypes(*arrs))
bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
if bad_shapes:
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
@ -940,7 +940,7 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
if T.shape[0] != Z.shape[0]:
raise ValueError(f"Input array shapes must match: Z: {Z.shape} vs. T: {T.shape}")
T, Z = _promote_dtypes_complex(T, Z)
T, Z = promote_dtypes_complex(T, Z)
eps = jnp.finfo(T.dtype).eps
N = T.shape[0]
@ -1020,10 +1020,10 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
@_wraps(scipy.linalg.toeplitz)
def toeplitz(c: ArrayLike, r: Optional[ArrayLike] = None) -> Array:
if r is None:
_check_arraylike("toeplitz", c)
check_arraylike("toeplitz", c)
r = jnp.conjugate(jnp.asarray(c))
else:
_check_arraylike("toeplitz", c, r)
check_arraylike("toeplitz", c, r)
c = jnp.asarray(c).flatten()
r = jnp.asarray(r).flatten()
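
As with the NumPy wrappers, these routines promote to an inexact dtype before calling into `lax_linalg`. A small sketch of the Cholesky pair above:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])      # symmetric positive definite
b = jnp.array([2.0, 1.0])

c, lower = cho_factor(A)         # Cholesky factor, as in scipy's cho_factor
x = cho_solve((c, lower), b)     # solve A @ x = b using that factor
print(x)                         # [0.5 0. ]
print(A @ x)                     # recovers b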

View File

@ -15,7 +15,7 @@
from typing import NamedTuple, Union
from functools import partial
from jax._src.numpy.util import _promote_dtypes_inexact
from jax._src.numpy.util import promote_dtypes_inexact
import jax.numpy as jnp
import jax
from jax import lax
@ -272,7 +272,7 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
Returns: LineSearchResults
"""
xk, pk = _promote_dtypes_inexact(xk, pk)
xk, pk = promote_dtypes_inexact(xk, pk)
def restricted_func_and_grad(t):
t = jnp.array(t, dtype=pk.dtype)
phi, g = jax.value_and_grad(f)(xk + t * pk)
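
`line_search` is the Wolfe search used by the BFGS path of `jax.scipy.optimize.minimize`; a sketch of it being exercised indirectly (the quadratic objective here is illustrative only):

import jax.numpy as jnp
from jax.scipy.optimize import minimize

fun = lambda x: jnp.sum((x - 1.0) ** 2)
res = minimize(fun, jnp.zeros(3), method='BFGS')
print(res.x)            # approximately [1. 1. 1.]
print(res.success)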

View File

@ -28,7 +28,7 @@ from jax._src import dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy import linalg
from jax._src.numpy.util import (
_check_arraylike, _wraps, _promote_dtypes_inexact, _promote_dtypes_complex)
check_arraylike, _wraps, promote_dtypes_inexact, promote_dtypes_complex)
from jax._src.third_party.scipy import signal_helper
from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
@ -44,7 +44,7 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike)
raise ValueError("in1 and in2 must have the same number of dimensions")
if in1.size == 0 or in2.size == 0:
raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
in1, in2 = _promote_dtypes_inexact(in1, in2)
in1, in2 = promote_dtypes_inexact(in1, in2)
no_swap = all(s1 >= s2 for s1, s2 in zip(in1.shape, in2.shape))
swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
@ -135,7 +135,7 @@ def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0,
raise NotImplementedError("overwrite_data argument not implemented.")
if type not in ['constant', 'linear']:
raise ValueError("Trend type must be 'linear' or 'constant'.")
data_arr, = _promote_dtypes_inexact(jnp.asarray(data))
data_arr, = promote_dtypes_inexact(jnp.asarray(data))
if type == 'constant':
return data_arr - data_arr.mean(axis, keepdims=True)
else:
@ -185,7 +185,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
# Apply window by multiplication
if jnp.iscomplexobj(win):
result, = _promote_dtypes_complex(result)
result, = promote_dtypes_complex(result)
result = win.reshape((1,) * len(batch_shape) + (1, nperseg)) * result
# Perform the fft on last axis. Zero-pads automatically
@ -267,15 +267,15 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
axis = canonicalize_axis(axis, x.ndim)
if y is None:
_check_arraylike('spectral_helper', x)
x, = _promote_dtypes_inexact(x)
check_arraylike('spectral_helper', x)
x, = promote_dtypes_inexact(x)
y_arr = x # place-holder for type checking
outershape = tuple_delete(x.shape, axis)
else:
if mode != 'psd':
raise ValueError("two-argument mode is available only when mode=='psd'")
_check_arraylike('spectral_helper', x, y)
x, y_arr = _promote_dtypes_inexact(x, y)
check_arraylike('spectral_helper', x, y)
x, y_arr = promote_dtypes_inexact(x, y)
if x.ndim != y_arr.ndim:
raise ValueError("two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
# Check if we can broadcast the outer axes together
@ -384,7 +384,7 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
raise ValueError(f'Unknown scaling: {scaling}')
if mode == 'stft':
scale = jnp.sqrt(scale)
scale, = _promote_dtypes_complex(scale)
scale, = promote_dtypes_complex(scale)
# Determine onesided/ two-sided
if return_onesided:
@ -513,7 +513,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
Returns:
An array with `(..., output_size)`-shape containing overlapped signal.
"""
_check_arraylike("_overlap_and_add", x)
check_arraylike("_overlap_and_add", x)
step_size = jax.core.concrete_or_error(int, step_size,
"step_size for overlap_and_add")
if x.ndim < 2:
@ -557,7 +557,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
boundary: bool = True, time_axis: int = -1,
freq_axis: int = -2) -> Tuple[Array, Array]:
# Input validation
_check_arraylike("istft", Zxx)
check_arraylike("istft", Zxx)
if Zxx.ndim < 2:
raise ValueError('Input stft must be at least 2d!')
freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
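
A short sketch of the promotion behaviour in this file through two public entry points, `convolve` and `detrend` (values are exact up to float rounding):

import jax.numpy as jnp
from jax.scipy.signal import convolve, detrend

# Inputs are promoted to a common inexact dtype before convolving.
x = jnp.array([1, 2, 3])
k = jnp.array([0.0, 1.0, 0.5])
print(convolve(x, k))            # [0.  1.  2.5 4.  1.5]  (mode='full')

# detrend with the default type='linear' removes the best-fit line.
t = jnp.arange(5.0)
print(detrend(2.0 * t + 1.0))    # ~[0. 0. 0. 0. 0.]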

View File

@ -29,7 +29,7 @@ from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _promote_args_inexact, _promote_dtypes_inexact
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
@ -38,7 +38,7 @@ from jax._src.typing import Array, ArrayLike
@_wraps(osp_special.gammaln, module='scipy.special')
def gammaln(x: ArrayLike) -> Array:
x, = _promote_args_inexact("gammaln", x)
x, = promote_args_inexact("gammaln", x)
return lax.lgamma(x)
@ -51,14 +51,14 @@ betaln = _wraps(
@_wraps(osp_special.betainc, module='scipy.special')
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
a, b, x = _promote_args_inexact("betainc", a, b, x)
a, b, x = promote_args_inexact("betainc", a, b, x)
return lax.betainc(a, b, x)
@_wraps(osp_special.digamma, module='scipy.special', lax_description="""\
The JAX version only accepts real-valued inputs.""")
def digamma(x: ArrayLike) -> Array:
x, = _promote_args_inexact("digamma", x)
x, = promote_args_inexact("digamma", x)
return lax.digamma(x)
ad.defjvp(
lax.digamma_p,
@ -67,39 +67,39 @@ ad.defjvp(
@_wraps(osp_special.gammainc, module='scipy.special', update_doc=False)
def gammainc(a: ArrayLike, x: ArrayLike) -> Array:
a, x = _promote_args_inexact("gammainc", a, x)
a, x = promote_args_inexact("gammainc", a, x)
return lax.igamma(a, x)
@_wraps(osp_special.gammaincc, module='scipy.special', update_doc=False)
def gammaincc(a: ArrayLike, x: ArrayLike) -> Array:
a, x = _promote_args_inexact("gammaincc", a, x)
a, x = promote_args_inexact("gammaincc", a, x)
return lax.igammac(a, x)
@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"],
lax_description="Note that the JAX version does not support complex inputs.")
def erf(x: ArrayLike) -> Array:
x, = _promote_args_inexact("erf", x)
x, = promote_args_inexact("erf", x)
return lax.erf(x)
@_wraps(osp_special.erfc, module='scipy.special', update_doc=False)
def erfc(x: ArrayLike) -> Array:
x, = _promote_args_inexact("erfc", x)
x, = promote_args_inexact("erfc", x)
return lax.erfc(x)
@_wraps(osp_special.erfinv, module='scipy.special')
def erfinv(x: ArrayLike) -> Array:
x, = _promote_args_inexact("erfinv", x)
x, = promote_args_inexact("erfinv", x)
return lax.erf_inv(x)
@api.custom_jvp
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
def logit(x: ArrayLike) -> Array:
x, = _promote_args_inexact("logit", x)
x, = promote_args_inexact("logit", x)
return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x)))
logit.defjvps(
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))
@ -107,7 +107,7 @@ logit.defjvps(
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
def expit(x: ArrayLike) -> Array:
x, = _promote_args_inexact("expit", x)
x, = promote_args_inexact("expit", x)
return lax.logistic(x)
@ -116,7 +116,7 @@ logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.lo
@_wraps(osp_special.xlogy, module='scipy.special')
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
x, y = _promote_args_inexact("xlogy", x, y)
x, y = promote_args_inexact("xlogy", x, y)
x_ok = x != 0.
safe_x = jnp.where(x_ok, x, 1.)
safe_y = jnp.where(x_ok, y, 1.)
@ -125,7 +125,7 @@ def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
x, y = _promote_args_inexact("xlog1py", x, y)
x, y = promote_args_inexact("xlog1py", x, y)
x_ok = x != 0.
safe_x = jnp.where(x_ok, x, 1.)
safe_y = jnp.where(x_ok, y, 1.)
@ -134,7 +134,7 @@ def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
@_wraps(osp_special.entr, module='scipy.special')
def entr(x: ArrayLike) -> Array:
x, = _promote_args_inexact("entr", x)
x, = promote_args_inexact("entr", x)
return lax.select(lax.lt(x, _lax_const(x, 0)),
lax.full_like(x, -np.inf),
lax.neg(xlogy(x, x)))
@ -143,7 +143,7 @@ def entr(x: ArrayLike) -> Array:
@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d_ = _promote_args_inexact("multigammaln", a, d)
a, d_ = promote_args_inexact("multigammaln", a, d)
constant = lax.mul(lax.mul(lax.mul(_lax_const(a, 0.25), d_),
lax.sub(d_, _lax_const(a, 1))),
@ -185,7 +185,7 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
# Numerical Algorithms 69.2 (2015): 253-270.
# https://arxiv.org/abs/1309.2877 - formula (5)
# here we keep the same notation as in reference
s, a = _promote_args_inexact("zeta", x, q)
s, a = promote_args_inexact("zeta", x, q)
dtype = lax.dtype(a).type
s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
# precision ~ N, M
@ -209,7 +209,7 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False)
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
assert jnp.issubdtype(lax.dtype(n), jnp.integer)
n_arr, x_arr = _promote_args_inexact("polygamma", n, x)
n_arr, x_arr = promote_args_inexact("polygamma", n, x)
shape = lax.broadcast_shapes(n_arr.shape, x_arr.shape)
return _polygamma(jnp.broadcast_to(n_arr, shape), jnp.broadcast_to(x_arr, shape))
@ -631,22 +631,22 @@ def _norm_logpdf(x):
@_wraps(osp_special.i0e, module='scipy.special')
def i0e(x: ArrayLike) -> Array:
x, = _promote_args_inexact("i0e", x)
x, = promote_args_inexact("i0e", x)
return lax.bessel_i0e(x)
@_wraps(osp_special.i0, module='scipy.special')
def i0(x: ArrayLike) -> Array:
x, = _promote_args_inexact("i0", x)
x, = promote_args_inexact("i0", x)
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x))
@_wraps(osp_special.i1e, module='scipy.special')
def i1e(x: ArrayLike) -> Array:
x, = _promote_args_inexact("i1e", x)
x, = promote_args_inexact("i1e", x)
return lax.bessel_i1e(x)
@_wraps(osp_special.i1, module='scipy.special')
def i1(x: ArrayLike) -> Array:
x, = _promote_args_inexact("i1", x)
x, = promote_args_inexact("i1", x)
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
def _bessel_jn_scan_body_fun(carry, k):
@ -713,7 +713,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
ValueError if elements of array `z` are not float.
"""
z = jnp.asarray(z)
z, = _promote_dtypes_inexact(z)
z, = promote_dtypes_inexact(z)
z_dtype = lax.dtype(z)
if dtypes.issubdtype(z_dtype, complex):
raise ValueError("complex input not supported.")
@ -1365,7 +1365,7 @@ def _expi_neg(x: Array) -> Array:
@jit
@_wraps(osp_special.expi, module='scipy.special')
def expi(x: ArrayLike) -> Array:
x_arr, = _promote_args_inexact("expi", x)
x_arr, = promote_args_inexact("expi", x)
return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos])
@ -1484,7 +1484,7 @@ def _expn3(n: int, x: Array) -> Array:
@_wraps(osp_special.expn, module='scipy.special')
@jit
def expn(n: ArrayLike, x: ArrayLike) -> Array:
n, x = _promote_args_inexact("expn", n, x)
n, x = promote_args_inexact("expn", n, x)
_c = _lax_const
zero = _c(x, 0)
one = _c(x, 1)
@ -1521,7 +1521,7 @@ def expn_jvp(n, primals, tangents):
@_wraps(osp_special.exp1, module="scipy.special")
def exp1(x: ArrayLike) -> Array:
x, = _promote_args_inexact("exp1", x)
x, = promote_args_inexact("exp1", x)
# Casting because custom_jvp generic does not work correctly with mypy.
return cast(Array, expn(1, x))
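
The scipy.special hunks above all follow one wrapper pattern: route the arguments through the now-public promote_args_inexact helper, then hand the promoted values to the matching lax primitive. A minimal sketch of that pattern, outside this PR (scaled_log1p is a made-up wrapper, not a JAX function):

from jax import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike

def scaled_log1p(x: ArrayLike, scale: ArrayLike = 1) -> Array:
  # Promote integer or weakly-typed inputs to a common inexact dtype,
  # mirroring the wrappers above (e.g. xlogy, gammainc).
  x, scale = promote_args_inexact("scaled_log1p", x, scale)
  return lax.div(lax.log1p(x), scale)

print(scaled_log1p(3))        # floating-point result even for an integer input
print(scaled_log1p(3, 2.0))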

View File

@ -22,7 +22,7 @@ import jax.numpy as jnp
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis
@ -35,7 +35,7 @@ Currently the only supported nan_policy is 'propagate'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
_check_arraylike("mode", a)
check_arraylike("mode", a)
x = jnp.atleast_1d(a)
if nan_policy not in ["propagate", "omit", "raise"]:
@ -100,7 +100,7 @@ def rankdata(
nan_policy: str = "propagate",
) -> Array:
_check_arraylike("rankdata", a)
check_arraylike("rankdata", a)
if nan_policy not in ["propagate", "omit", "raise"]:
raise ValueError(

View File

@ -18,14 +18,14 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, xlog1py
@_wraps(osp_stats.bernoulli.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = _promote_args_inexact("bernoulli.logpmf", k, p, loc)
k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)
@ -39,7 +39,7 @@ def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
k, p = _promote_args_inexact('bernoulli.cdf', k, p)
k, p = promote_args_inexact('bernoulli.cdf', k, p)
zero, one = _lax_const(k, 0), _lax_const(k, 1)
conds = [
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
@ -52,7 +52,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array:
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
q, p = _promote_args_inexact('bernoulli.ppf', q, p)
q, p = promote_args_inexact('bernoulli.ppf', q, p)
zero, one = _lax_const(q, 0), _lax_const(q, 1)
return jnp.where(
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),

View File

@ -17,7 +17,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, xlogy, xlog1py
@ -25,7 +25,7 @@ from jax.scipy.special import betaln, xlogy, xlog1py
@_wraps(osp_stats.beta.logpdf, update_doc=False)
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
one = _lax_const(x, 1)
shape_term = lax.neg(betaln(a, b))
y = lax.div(lax.sub(x, loc), scale)

View File

@ -18,7 +18,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.scipy.special import betaln
from jax._src.typing import Array, ArrayLike
@ -27,7 +27,7 @@ from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.betabinom.logpmf."""
k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
y = lax.sub(lax.floor(k), loc)
one = _lax_const(y, 1)
zero = _lax_const(y, 0)

View File

@ -19,13 +19,13 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _promote_args_inexact
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.mul(pi, scale))

View File

@ -18,13 +18,13 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)

View File

@ -18,7 +18,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _promote_dtypes_inexact, _wraps
from jax._src.numpy.util import promote_dtypes_inexact, _wraps
from jax.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@ -30,7 +30,7 @@ def _is_simplex(x: Array) -> Array:
@_wraps(osp_stats.dirichlet.logpdf, update_doc=False)
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
return _logpdf(*_promote_dtypes_inexact(x, alpha))
return _logpdf(*promote_dtypes_inexact(x, alpha))
def _logpdf(x: Array, alpha: Array) -> Array:
if alpha.ndim != 1:

View File

@ -16,13 +16,13 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.expon.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale)
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
linear_term = lax.div(lax.sub(x, loc), scale)
log_probs = lax.neg(lax.add(linear_term, log_scale))

View File

@ -17,14 +17,14 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
one = _lax_const(x, 1)
y = lax.div(lax.sub(x, loc), scale)
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)

View File

@ -14,18 +14,18 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, p: ArrayLike) -> Array:
x, p = _promote_args_inexact("gennorm.logpdf", x, p)
x, p = promote_args_inexact("gennorm.logpdf", x, p)
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
def cdf(x: ArrayLike, p: ArrayLike) -> Array:
x, p = _promote_args_inexact("gennorm.cdf", x, p)
x, p = promote_args_inexact("gennorm.cdf", x, p)
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
@_wraps(osp_stats.gennorm.pdf, update_doc=False)

View File

@ -17,14 +17,14 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax.scipy.special import xlog1py
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.geom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = _promote_args_inexact("geom.logpmf", k, p, loc)
k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)

View File

@ -21,7 +21,7 @@ import scipy.stats as osp_stats
import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special
@ -37,7 +37,7 @@ class gaussian_kde:
inv_cov: Any
def __init__(self, dataset, bw_method=None, weights=None):
_check_arraylike("gaussian_kde", dataset)
check_arraylike("gaussian_kde", dataset)
dataset = jnp.atleast_2d(dataset)
if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
raise NotImplementedError("gaussian_kde does not support complex data")
@ -46,8 +46,8 @@ class gaussian_kde:
d, n = dataset.shape
if weights is not None:
_check_arraylike("gaussian_kde", weights)
dataset, weights = _promote_dtypes_inexact(dataset, weights)
check_arraylike("gaussian_kde", weights)
dataset, weights = promote_dtypes_inexact(dataset, weights)
weights = jnp.atleast_1d(weights)
weights /= jnp.sum(weights)
if weights.ndim != 1:
@ -55,7 +55,7 @@ class gaussian_kde:
if len(weights) != n:
raise ValueError("`weights` input should be of length n")
else:
dataset, = _promote_dtypes_inexact(dataset)
dataset, = promote_dtypes_inexact(dataset)
weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)
self._setattr("dataset", dataset)
@ -115,7 +115,7 @@ class gaussian_kde:
@_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
def evaluate(self, points):
_check_arraylike("evaluate", points)
check_arraylike("evaluate", points)
points = self._reshape_points(points)
result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
points.T, self.inv_cov)
@ -195,7 +195,7 @@ class gaussian_kde:
@_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False)
def logpdf(self, x):
_check_arraylike("logpdf", x)
check_arraylike("logpdf", x)
x = self._reshape_points(x)
result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
x.T, self.inv_cov)
@ -238,7 +238,7 @@ def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
@partial(jit, static_argnums=0)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
points, values, xi, precision = _promote_dtypes_inexact(
points, values, xi, precision = promote_dtypes_inexact(
points, values, xi, precision)
d = points.shape[1]
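
The kde.py changes pair the two renamed helpers: check_arraylike rejects inputs that are neither arrays nor scalars, and promote_dtypes_inexact casts the accepted inputs to a shared floating dtype without touching their shapes. A rough sketch of that validate-then-promote idiom (weighted_mean is hypothetical, not part of JAX):

import jax.numpy as jnp
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact

def weighted_mean(data, weights):
  # Fail early with a clear TypeError for non-array inputs (e.g. Python lists).
  check_arraylike("weighted_mean", data, weights)
  # Cast both arrays to a common inexact dtype; shapes are untouched.
  data, weights = promote_dtypes_inexact(data, weights)
  weights = weights / jnp.sum(weights)
  return jnp.sum(data * weights)

print(weighted_mean(jnp.arange(4), jnp.ones(4)))   # promoted to float before the mean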

View File

@ -16,13 +16,13 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.laplace.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
two = _lax_const(x, 2)
linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
@ -35,7 +35,7 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.laplace.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
half = _lax_const(x, 0.5)
one = _lax_const(x, 1)
zero = _lax_const(x, 0)

View File

@ -18,13 +18,13 @@ from jax.scipy.special import expit, logit
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x: ArrayLike) -> Array:
x, = _promote_args_inexact("logistic.logpdf", x)
x, = promote_args_inexact("logistic.logpdf", x)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))

View File

@ -16,7 +16,7 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact, _promote_args_numeric
from jax._src.numpy.util import _wraps, promote_args_inexact, promote_args_numeric
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@ -24,8 +24,8 @@ from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.multinomial.logpmf, update_doc=False)
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
"""JAX implementation of scipy.stats.multinomial.logpmf."""
p, = _promote_args_inexact("multinomial.logpmf", p)
x, n = _promote_args_numeric("multinomial.logpmf", x, n)
p, = promote_args_inexact("multinomial.logpmf", p)
x, n = promote_args_numeric("multinomial.logpmf", x, n)
if not jnp.issubdtype(x.dtype, jnp.integer):
raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}")
x = x.astype(p.dtype)

View File

@ -19,7 +19,7 @@ import scipy.stats as osp_stats
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
@ -29,7 +29,7 @@ In the JAX version, the `allow_singular` argument is not implemented.
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike:
if allow_singular is not None:
raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")
x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
x, mean, cov = promote_dtypes_inexact(x, mean, cov)
if not mean.shape:
return (-1/2 * jnp.square(x - mean) / cov
- 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))

View File

@ -18,7 +18,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@ -26,7 +26,7 @@ from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.nbinom.logpmf."""
k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
one = _lax_const(k, 1)
y = lax.sub(k, loc)
comb_term = lax.sub(

View File

@ -20,13 +20,13 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy import special
@_wraps(osp_stats.norm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
scale_sqrd = lax.square(scale)
log_normalizer = lax.log(lax.mul(_lax_const(x, 2 * np.pi), scale_sqrd))
quadratic = lax.div(lax.square(lax.sub(x, loc)), scale_sqrd)
@ -40,13 +40,13 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.norm.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("norm.cdf", x, loc, scale)
x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
return special.ndtr(lax.div(lax.sub(x, loc), scale))
@_wraps(osp_stats.norm.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("norm.logcdf", x, loc, scale)
x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
# Cast required because custom_jvp return type is broken.
return cast(Array, special.log_ndtr(lax.div(lax.sub(x, loc), scale)))

View File

@ -18,13 +18,13 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.pareto.logpdf, update_doc=False)
def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
one = _lax_const(x, 1)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.div(scale, b))

View File

@ -18,14 +18,14 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, gammaln, gammaincc
@_wraps(osp_stats.poisson.logpmf, update_doc=False)
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
@ -37,7 +37,7 @@ def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
@_wraps(osp_stats.poisson.cdf, update_doc=False)
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
k, mu, loc = promote_args_inexact("poisson.cdf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
p = gammaincc(jnp.floor(1 + x), mu)

View File

@ -18,13 +18,13 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.t.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
two = _lax_const(x, 2)
scaled_x = lax.div(lax.sub(x, loc), scale)
df_over_two = lax.div(df, two)

View File

@ -17,7 +17,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
@ -71,7 +71,7 @@ def _log_gauss_mass(a, b):
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
def logpdf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
x_scaled = lax.div(lax.sub(x, loc), scale)
@ -87,7 +87,7 @@ def pdf(x, a, b, loc=0, scale=1):
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
x = lax.div(lax.sub(x, loc), scale)
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
@ -109,7 +109,7 @@ def sf(x, a, b, loc=0, scale=1):
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
def logcdf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
x = lax.div(lax.sub(x, loc), scale)
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
from jax.numpy import where, inf, logical_or
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
@_wraps(osp_stats.uniform.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
log_probs = lax.neg(lax.log(scale))
return where(logical_or(lax.gt(x, lax.add(loc, scale)),
lax.lt(x, loc)),

View File

@ -17,12 +17,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.vonmises.logpdf, update_doc=False)
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
x, kappa = _promote_args_inexact('vonmises.logpdf', x, kappa)
x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
zero = _lax_const(kappa, 0)
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan)

View File

@ -47,7 +47,7 @@ from jax._src.interpreters import pxla
from jax._src.config import (flags, bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.util import _promote_dtypes, _promote_dtypes_inexact
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@ -827,7 +827,7 @@ def promote_like_jnp(fun, inexact=False):
tests make an np reference implementation act more like a jnp
implementation.
"""
_promote = _promote_dtypes_inexact if inexact else _promote_dtypes
_promote = promote_dtypes_inexact if inexact else promote_dtypes
def wrapper(*args, **kw):
flat_args, tree = tree_flatten(args)
args = tree_unflatten(tree, _promote(*flat_args))
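
promote_like_jnp is the test-side counterpart of the renamed promotion helpers: it pre-promotes the arguments of a NumPy reference function so that its dtypes line up with what the jnp implementation would produce. A hedged usage sketch, assuming the wrapper simply forwards the promoted arguments to the wrapped function:

import numpy as np
from jax._src import test_util as jtu

# Wrap a NumPy reference so its integer inputs are first promoted to an
# inexact dtype, the same way the jnp implementation would promote them.
np_sin_ref = jtu.promote_like_jnp(np.sin, inexact=True)
print(np_sin_ref(np.arange(3)))   # evaluated on the promoted (floating) inputs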

View File

@ -2,7 +2,7 @@ import numpy as np
import jax.numpy as jnp
import jax.numpy.linalg as la
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, _wraps
def _isEmpty2d(arr):
@ -41,7 +41,7 @@ def _assert2d(*arrays):
@_wraps(np.linalg.cond)
def cond(x, p=None):
_check_arraylike('jnp.linalg.cond', x)
check_arraylike('jnp.linalg.cond', x)
_assertNoEmpty2d(x)
if p in (None, 2):
s = la.svd(x, compute_uv=False)
@ -64,7 +64,7 @@ def cond(x, p=None):
@_wraps(np.linalg.tensorinv)
def tensorinv(a, ind=2):
_check_arraylike('jnp.linalg.tensorinv', a)
check_arraylike('jnp.linalg.tensorinv', a)
a = jnp.asarray(a)
oldshape = a.shape
prod = 1
@ -81,7 +81,7 @@ def tensorinv(a, ind=2):
@_wraps(np.linalg.tensorsolve)
def tensorsolve(a, b, axes=None):
_check_arraylike('jnp.linalg.tensorsolve', a, b)
check_arraylike('jnp.linalg.tensorsolve', a, b)
a = jnp.asarray(a)
b = jnp.asarray(b)
an = a.ndim
@ -110,7 +110,7 @@ def tensorsolve(a, b, axes=None):
@_wraps(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
_check_arraylike('jnp.linalg.multi_dot', *arrays)
check_arraylike('jnp.linalg.multi_dot', *arrays)
n = len(arrays)
# optimization only makes sense for len(arrays) > 2
if n < 2:

View File

@ -1,7 +1,7 @@
from jax import lax
import jax.numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import _promote_args_inexact
from jax._src.numpy.util import promote_args_inexact
# Note: for mysterious reasons, annotating this leads to very slow mypy runs.
# def algdiv(a: ArrayLike, b: ArrayLike) -> Array:
@ -58,7 +58,7 @@ def betaln(a: ArrayLike, b: ArrayLike) -> Array:
.. _betaln:
https://github.com/scipy/scipy/blob/ef2dee592ba8fb900ff2308b9d1c79e4d6a0ad8b/scipy/special/cdflib/betaln.f
"""
a, b = _promote_args_inexact("betaln", a, b)
a, b = promote_args_inexact("betaln", a, b)
a, b = jnp.minimum(a, b), jnp.maximum(a, b)
small_b = lax.lgamma(a) + (lax.lgamma(b) - lax.lgamma(a + b))
large_b = lax.lgamma(a) + algdiv(a, b)

View File

@ -4,7 +4,7 @@ import scipy.interpolate as osp_interpolate
from jax.numpy import (asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
def _ndim_coords_from_arrays(points, ndim=None):
@ -21,7 +21,7 @@ def _ndim_coords_from_arrays(points, ndim=None):
for j, item in enumerate(p):
points = points.at[..., j].set(item)
else:
_check_arraylike("_ndim_coords_from_arrays", points)
check_arraylike("_ndim_coords_from_arrays", points)
points = asarray(points) # SciPy: asanyarray(points)
if points.ndim == 1:
if ndim is None:
@ -56,15 +56,15 @@ class RegularGridInterpolator:
if self.bounds_error:
raise NotImplementedError("`bounds_error` has no effect under JIT")
_check_arraylike("RegularGridInterpolator", values)
check_arraylike("RegularGridInterpolator", values)
if len(points) > values.ndim:
ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
raise ValueError(ve)
values, = _promote_dtypes_inexact(values)
values, = promote_dtypes_inexact(values)
if fill_value is not None:
_check_arraylike("RegularGridInterpolator", fill_value)
check_arraylike("RegularGridInterpolator", fill_value)
fill_value = asarray(fill_value)
if not can_cast(fill_value.dtype, values.dtype, casting='same_kind'):
ve = "fill_value must be either 'None' or of a type compatible with values"
@ -72,7 +72,7 @@ class RegularGridInterpolator:
self.fill_value = fill_value
# TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
_check_arraylike("RegularGridInterpolator", *points)
check_arraylike("RegularGridInterpolator", *points)
self.grid = tuple(asarray(p) for p in points)
self.values = values

View File

@ -34,7 +34,7 @@ import jax.numpy as jnp
from jax._src import core
from jax import custom_derivatives
from jax import lax
from jax._src.numpy.util import _promote_dtypes_inexact
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.util import safe_map, safe_zip
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves, tree_map
@ -76,7 +76,7 @@ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
# Algorithm from:
# E. Hairer, S. P. Norsett G. Wanner,
# Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
y0, f0 = _promote_dtypes_inexact(y0, f0)
y0, f0 = promote_dtypes_inexact(y0, f0)
dtype = y0.dtype
scale = atol + jnp.abs(y0) * rtol

View File

@ -34,7 +34,7 @@ from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import gpu_sparse
from jax._src.numpy.util import _promote_dtypes
from jax._src.numpy.util import promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
import jax.numpy as jnp
@ -156,7 +156,7 @@ class COO(JAXSparse):
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
other = jnp.asarray(other)
data, other = _promote_dtypes(self.data, other)
data, other = promote_dtypes(self.data, other)
self_promoted = COO((data, self.row, self.col), **self._info._asdict())
if other.ndim == 1:
return coo_matvec(self_promoted, other)

View File

@ -34,7 +34,7 @@ from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse
from jax._src.numpy.util import _promote_dtypes
from jax._src.numpy.util import promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
import jax.numpy as jnp
@ -117,7 +117,7 @@ class CSR(JAXSparse):
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
other = jnp.asarray(other)
data, other = _promote_dtypes(self.data, other)
data, other = promote_dtypes(self.data, other)
if other.ndim == 1:
return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape)
elif other.ndim == 2:
@ -184,7 +184,7 @@ class CSC(JAXSparse):
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
other = jnp.asarray(other)
data, other = _promote_dtypes(self.data, other)
data, other = promote_dtypes(self.data, other)
if other.ndim == 1:
return _csr_matvec(data, self.indices, self.indptr, other,
shape=self.shape[::-1], transpose=True)
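
Note that the sparse matmul paths use plain promote_dtypes rather than the inexact variant, presumably so that, as with dense matmul, integer-times-integer stays integer. A small illustration of the difference (common_dtype is a throwaway helper, not JAX API):

import jax.numpy as jnp
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact

def common_dtype(promote, a, b):
  # Return the dtype both operands end up with after promotion.
  a, b = promote(a, b)
  return a.dtype

ints = jnp.arange(3, dtype=jnp.int32)
print(common_dtype(promote_dtypes, ints, ints))          # stays int32
print(common_dtype(promote_dtypes_inexact, ints, ints))  # forced to a float dtype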

View File

@ -25,7 +25,7 @@ from jax import lax
from jax import numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import _promote_dtypes_complex
from jax._src.numpy.util import promote_dtypes_complex
from jax.config import config
config.parse_flags_with_absl()
@ -175,7 +175,7 @@ class FftTest(jtu.JaxTestCase):
return jax.vmap(linear_func)(jnp.eye(size, size))
def func(x):
x, = _promote_dtypes_complex(x)
x, = promote_dtypes_complex(x)
return jnp.fft.irfft(jnp.concatenate([jnp.zeros_like(x, shape=1),
x[:2] + 1j*x[2:]]))
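
The final hunk promotes a real test vector to a complex dtype before feeding it to jnp.fft.irfft. A standalone sketch of promote_dtypes_complex, assuming only the behavior visible in this diff (a real input becomes the matching complex dtype):

import jax.numpy as jnp
from jax._src.numpy.util import promote_dtypes_complex

x = jnp.arange(4.0)              # float32 by default (float64 with jax_enable_x64)
xc, = promote_dtypes_complex(x)  # complex64 (complex128 under x64)
print(xc.dtype)
print(jnp.fft.ifft(xc).dtype)    # the transform keeps the promoted complex dtype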