diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index 6fcc4e722..ecde2997e 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -33,6 +33,7 @@ from typing import NamedTuple import jax import jax._src.numpy.lax_numpy as jnp import jax._src.numpy.linalg as jnp_linalg +from jax._src.numpy import ufuncs from jax import lax from jax._src.lax import qdwh from jax._src.lax import linalg as lax_linalg @@ -163,7 +164,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2): _, _, j, error = args still_counting = j < maxiter unconverged = error > thresh - return jnp.logical_and(still_counting, unconverged)[0] + return ufuncs.logical_and(still_counting, unconverged)[0] def body_f(args): V1, _, j, _ = args @@ -204,7 +205,7 @@ def split_spectrum(H, n, split_point, V0=None): H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype) U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n)) P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n))) - rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32) + rank = jnp.round(jnp.trace(ufuncs.real(P))).astype(jnp.int32) V_minus, V_plus = _projector_subspace(P, H, n, rank) H_minus = (V_minus.conj().T @ H) @ V_minus @@ -359,7 +360,7 @@ def _eigh_work(H, n, termination_size=256): def default_case(agenda, blocks, eigenvectors): V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # TODO: Improve this? 
- split_point = jnp.nanmedian(_mask(jnp.diag(jnp.real(H)), (b,), jnp.nan)) + split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) H_minus, V_minus, H_plus, V_plus, rank = split_spectrum( H, b, split_point, V0=V) @@ -381,7 +382,7 @@ def _eigh_work(H, n, termination_size=256): norm = jnp_linalg.norm(H) tol = jnp.asarray(10 * jnp.finfo(H.dtype).eps / 2, dtype=norm.dtype) off_diag_norm = jnp_linalg.norm( - H - jnp.diag(jnp.diag(jnp.real(H)).astype(H.dtype))) + H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype))) # We also handle nearly-all-zero matrices matrices here. nearly_diagonal = (norm < tol) | (off_diag_norm / norm < tol) return lax.cond(nearly_diagonal, nearly_diagonal_case, default_case, @@ -450,7 +451,7 @@ def eigh(H, *, precision="float32", termination_size=256, n=None, n = N if n is None else n with jax.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size) - eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan) + eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan) if sort_eigenvalues: sort_idxs = jnp.argsort(eig_vals) eig_vals = eig_vals[sort_idxs] diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index f0d94f7bc..ce0754cb9 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -49,6 +49,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import reductions +from jax._src.numpy import ufuncs from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike @@ -375,7 +377,7 @@ def _solve(a: Array, b: Array) -> Array: return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) def _T(x: Array) -> Array: return jnp.swapaxes(x, -1, -2) -def _H(x: Array) -> Array: return jnp.conj(_T(x)) +def _H(x: Array) -> Array: return ufuncs.conj(_T(x)) def symmetrize(x: Array) -> Array: return (x + 
_H(x)) / 2 # primitives @@ -551,7 +553,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, a, = primals da, = tangents l, v = eig(a, compute_left_eigenvectors=False) - return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)] + return [l], [reductions.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)] eig_p = Primitive('eig') eig_p.multiple_results = True @@ -679,13 +681,13 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues): if lower: mask = jnp.tri(n, k=0, dtype=bool) else: - mask = jnp.logical_not(jnp.tri(n, k=-1, dtype=bool)) + mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool)) if dtypes.issubdtype(x.dtype, jnp.complexfloating): re = lax.select(mask, lax.real(x), _T(lax.real(x))) if lower: im_mask = jnp.tri(n, k=-1, dtype=bool) else: - im_mask = jnp.logical_not(jnp.tri(n, k=0, dtype=bool)) + im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool)) im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x))) im = lax.select(mask, im, -_T(im)) x = lax.complex(re, im) @@ -717,13 +719,13 @@ def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues): w = w_real.astype(a.dtype) eye_n = jnp.eye(a.shape[-1], dtype=a.dtype) # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs. - Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n + Fmat = ufuncs.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n # eigh impl doesn't support batch dims, but future-proof the grad. 
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) vdag_adot_v = dot(dot(_H(v), a_dot), v) - dv = dot(v, jnp.multiply(Fmat, vdag_adot_v)) - dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1)) + dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v)) + dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1)) return (v, w_real), (dv, dw) def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues): @@ -789,7 +791,7 @@ def _triangular_solve_jvp_rule_a( g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k) g_a = lax.neg(g_a) g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a - g_a = jnp.conj(g_a) if conjugate_a else g_a + g_a = ufuncs.conj(g_a) if conjugate_a else g_a dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) @@ -1029,9 +1031,9 @@ def _lu_unblocked(a): if jnp.issubdtype(a.dtype, jnp.complexfloating): t = a[:, k] - magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t)) + magnitude = ufuncs.abs(ufuncs.real(t)) + ufuncs.abs(ufuncs.imag(t)) else: - magnitude = jnp.abs(a[:, k]) + magnitude = ufuncs.abs(a[:, k]) i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) pivot = pivot.at[k].set(i) a = a.at[[k, i],].set(a[[i, k],]) @@ -1553,7 +1555,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv): Ut, V = _H(U), _H(Vt) s_dim = s[..., None, :] dS = Ut @ dA @ V - ds = jnp.real(jnp.diagonal(dS, 0, -2, -1)) + ds = ufuncs.real(jnp.diagonal(dS, 0, -2, -1)) if not compute_uv: return (s,), (ds,) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 30d34ca72..92d7f1c61 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -23,6 +23,7 @@ from jax._src.lib import xla_client from jax._src.util import safe_zip from jax._src.numpy.util import _check_arraylike, _wraps from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import ufuncs, reductions from jax._src.typing import Array, ArrayLike Shape = 
Sequence[int] @@ -31,9 +32,9 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: if norm == "backward": return jnp.array(1) elif norm == "ortho": - return jnp.sqrt(jnp.prod(s)) if func_name.startswith('i') else 1/jnp.sqrt(jnp.prod(s)) + return ufuncs.sqrt(reductions.prod(s)) if func_name.startswith('i') else 1/ufuncs.sqrt(reductions.prod(s)) elif norm == "forward": - return jnp.prod(s) if func_name.startswith('i') else 1/jnp.prod(s) + return reductions.prod(s) if func_name.startswith('i') else 1/reductions.prod(s) raise ValueError(f'Invalid norm value {norm}; should be "backward",' '"ortho" or "forward".') @@ -175,7 +176,7 @@ def irfft(a: ArrayLike, n: Optional[int] = None, @_wraps(np.fft.hfft) def hfft(a: ArrayLike, n: Optional[int] = None, axis: int = -1, norm: Optional[str] = None) -> Array: - conj_a = jnp.conj(a) + conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis, @@ -189,7 +190,7 @@ def ihfft(a: ArrayLike, n: Optional[int] = None, nn = arr.shape[axis] if n is None else n output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, arr, n=n, axis=axis, norm=norm) - return jnp.conj(output) * (1 / nn) + return ufuncs.conj(output) * (1 / nn) def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 03f3b6214..1039fd6c3 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -16,9 +16,11 @@ import abc from typing import Any, Iterable, List, Tuple, Union import jax -import jax._src.numpy.lax_numpy as jnp from jax._src import core from jax._src.numpy.util import _promote_dtypes +from jax._src.numpy.lax_numpy import ( + arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose +) from jax._src.typing import Array, ArrayLike import numpy as np @@ -35,9 +37,9 @@ def 
_make_1d_grid_from_slice(s: slice, op_name: str) -> Array: step = core.concrete_or_error(None, s.step, f"slice step of jnp.{op_name}") or 1 if np.iscomplex(step): - newobj = jnp.linspace(start, stop, int(abs(step))) + newobj = linspace(start, stop, int(abs(step))) else: - newobj = jnp.arange(start, stop, step) + newobj = arange(start, stop, step) return newobj @@ -53,12 +55,12 @@ class _IndexGrid(abc.ABC): output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) with jax.numpy_dtype_promotion('standard'): output = _promote_dtypes(*output) - output_arr = jnp.meshgrid(*output, indexing='ij', sparse=self.sparse) + output_arr = meshgrid(*output, indexing='ij', sparse=self.sparse) if self.sparse: return output_arr if len(output_arr) == 0: - return jnp.arange(0) - return jnp.stack(output_arr, 0) + return arange(0) + return stack(output_arr, 0) class _Mgrid(_IndexGrid): @@ -178,10 +180,10 @@ class _AxisConcat(abc.ABC): elif isinstance(item, str): raise ValueError("string directive must be placed at the beginning") else: - newobj = jnp.array(item, copy=False) + newobj = array(item, copy=False) item_ndim = newobj.ndim - newobj = jnp.array(newobj, copy=False, ndmin=ndmin) + newobj = array(newobj, copy=False, ndmin=ndmin) if trans1d != -1 and ndmin - item_ndim > 0: shape_obj = tuple(range(ndmin)) @@ -189,15 +191,15 @@ class _AxisConcat(abc.ABC): num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) - newobj = jnp.transpose(newobj, shape_obj) + newobj = transpose(newobj, shape_obj) output.append(newobj) - res = jnp.concatenate(tuple(output), axis=axis) + res = concatenate(tuple(output), axis=axis) if matrix != -1 and res.ndim == 1: # insert 2nd dim at axis 0 or 1 - res = jnp.expand_dims(res, matrix) + res = expand_dims(res, matrix) return res diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6e28ffba3..ed430d46c 100644 --- 
a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -57,28 +57,9 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, from jax._src.lax import lax as lax_internal from jax._src.lib import pmap_lib from jax._src.lib import xla_client -from jax._src.numpy.reductions import ( # noqa: F401 - _ensure_optional_axes, _reduction_dims, - alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct, - max, mean, min, nancumsum, nancumprod, nanmax, nanmean, nanmin, nanprod, nanstd, - nansum, nanvar, prod, product, ptp, sometrue, std, sum, var, -) -from jax._src.numpy.ufuncs import ( # noqa: F401 - abs, absolute, add, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, - bitwise_and, bitwise_not, bitwise_or, bitwise_xor, cbrt, ceil, conj, conjugate, - copysign, cos, cosh, deg2rad, degrees, divide, divmod, equal, exp, exp2, expm1, - fabs, float_power, floor, floor_divide, fmod, frexp, greater, greater_equal, - heaviside, hypot, imag, invert, isfinite, isinf, isnan, isneginf, isposinf, - ldexp, left_shift, less, less_equal, log, log10, log1p, log2, logaddexp, logaddexp2, - logical_and, logical_not, logical_or, logical_xor, maximum, minimum, mod, modf, - multiply, negative, nextafter, not_equal, positive, power, rad2deg, radians, real, - reciprocal, remainder, right_shift, rint, sign, signbit, sin, sinc, sinh, sqrt, - square, subtract, tan, tanh, true_divide) -from jax._src.numpy.util import ( # noqa: F401 - _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, - _complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes, - _promote_dtypes_numeric, _promote_dtypes_inexact, _promote_shapes, - _register_stackable, _stackable, _where, _wraps) +from jax._src.numpy import reductions +from jax._src.numpy import ufuncs +from jax._src.numpy import util from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, 
Shape @@ -140,7 +121,7 @@ get_printoptions = np.get_printoptions printoptions = np.printoptions set_printoptions = np.set_printoptions -@_wraps(np.iscomplexobj) +@util._wraps(np.iscomplexobj) def iscomplexobj(x: Any) -> bool: try: typ = x.dtype.type @@ -242,7 +223,7 @@ array_repr = np.array_repr save = np.save savez = np.savez -@_wraps(np.dtype) +@util._wraps(np.dtype) def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False, copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" @@ -305,7 +286,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) -@_wraps(np.load, update_doc=False) +@util._wraps(np.load, update_doc=False) def load(*args: Any, **kwargs: Any) -> Array: # The main purpose of this wrapper is to recover bfloat16 data types. # Note: this will only work for files created via np.save(), not np.savez(). @@ -322,21 +303,21 @@ def load(*args: Any, **kwargs: Any) -> Array: ### implementations of numpy functions in terms of lax -@_wraps(np.fmin, module='numpy') +@util._wraps(np.fmin, module='numpy') @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: - return where(less(x1, x2) | isnan(x2), x1, x2) + return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) -@_wraps(np.fmax, module='numpy') +@util._wraps(np.fmax, module='numpy') @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: - return where(greater(x1, x2) | isnan(x2), x1, x2) + return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) -@_wraps(np.issubdtype) +@util._wraps(np.issubdtype) def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) -@_wraps(np.isscalar) +@util._wraps(np.isscalar) def isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() @@ -344,20 +325,20 @@ def isscalar(element: Any) -> bool: iterable = np.iterable -@_wraps(np.result_type) +@util._wraps(np.result_type) def 
result_type(*args: ArrayLike) -> DType: return dtypes.result_type(*args) -@_wraps(np.trapz) +@util._wraps(np.trapz) @partial(jit, static_argnames=('axis',)) def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: if x is None: - _check_arraylike('trapz', y) - y_arr, = _promote_dtypes_inexact(y) + util._check_arraylike('trapz', y) + y_arr, = util._promote_dtypes_inexact(y) else: - _check_arraylike('trapz', y, x) - y_arr, x_arr = _promote_dtypes_inexact(y, x) + util._check_arraylike('trapz', y, x) + y_arr, x_arr = util._promote_dtypes_inexact(y, x) if x_arr.ndim == 1: dx = diff(x_arr) else: @@ -366,24 +347,24 @@ def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis return 0.5 * (dx * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) -@_wraps(np.trunc, module='numpy') +@util._wraps(np.trunc, module='numpy') @jit def trunc(x: ArrayLike) -> Array: - _check_arraylike('trunc', x) - return where(lax.lt(x, _lax_const(x, 0)), ceil(x), floor(x)) + util._check_arraylike('trunc', x) + return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) @partial(jit, static_argnums=(2, 3, 4)) def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array: if ndim(x) != 1 or ndim(y) != 1: raise ValueError(f"{op}() only support 1-dimensional inputs.") - x, y = _promote_dtypes_inexact(x, y) + x, y = util._promote_dtypes_inexact(x, y) if len(x) == 0 or len(y) == 0: raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.") out_order = slice(None) if op == 'correlate': - y = conj(y) + y = ufuncs.conj(y) if len(x) < len(y): x, y = y, x out_order = slice(None, None, -1) @@ -406,30 +387,30 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> A return result[0, 0, out_order] -@_wraps(np.convolve, lax_description=_PRECISION_DOC) +@util._wraps(np.convolve, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) 
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None) -> Array: - _check_arraylike("convolve", a, v) + util._check_arraylike("convolve", a, v) return _conv(asarray(a), asarray(v), mode, 'convolve', precision) -@_wraps(np.correlate, lax_description=_PRECISION_DOC) +@util._wraps(np.correlate, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None) -> Array: - _check_arraylike("correlate", a, v) + util._check_arraylike("correlate", a, v) return _conv(asarray(a), asarray(v), mode, 'correlate', precision) -@_wraps(np.histogram_bin_edges) +@util._wraps(np.histogram_bin_edges) def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: Union[None, Array, Sequence[ArrayLike]] = None, weights: Optional[ArrayLike] = None) -> Array: del weights # unused, because string bins is not supported. if isinstance(bins, str): raise NotImplementedError("string values for `bins` not implemented.") - _check_arraylike("histogram_bin_edges", a, bins) + util._check_arraylike("histogram_bin_edges", a, bins) arr = ravel(a) dtype = dtypes.to_inexact_dtype(arr.dtype) if _ndim(bins) == 1: @@ -442,26 +423,26 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range = asarray(range, dtype=dtype) if shape(range) != (2,): raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}") - range = (where(ptp(range) == 0, range[0] - 0.5, range[0]), - where(ptp(range) == 0, range[1] + 0.5, range[1])) + range = (where(reductions.ptp(range) == 0, range[0] - 0.5, range[0]), + where(reductions.ptp(range) == 0, range[1] + 0.5, range[1])) assert range is not None return linspace(range[0], range[1], bins_int + 1, dtype=dtype) -@_wraps(np.histogram) +@util._wraps(np.histogram) def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Optional[Sequence[ArrayLike]] = None, weights: Optional[ArrayLike] = 
None, density: Optional[bool] = None) -> Tuple[Array, Array]: if weights is None: - _check_arraylike("histogram", a, bins) - a = ravel(*_promote_dtypes_inexact(a)) + util._check_arraylike("histogram", a, bins) + a = ravel(*util._promote_dtypes_inexact(a)) weights = ones_like(a) else: - _check_arraylike("histogram", a, bins, weights) + util._check_arraylike("histogram", a, bins, weights) if shape(a) != shape(weights): raise ValueError("weights should have the same shape as a.") - a, weights = map(ravel, _promote_dtypes_inexact(a, weights)) + a, weights = map(ravel, util._promote_dtypes_inexact(a, weights)) bin_edges = histogram_bin_edges(a, bins, range, weights) bin_idx = searchsorted(bin_edges, a, side='right') @@ -472,12 +453,12 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, counts = counts / bin_widths / counts.sum() return counts, bin_edges -@_wraps(np.histogram2d) +@util._wraps(np.histogram2d) def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None, weights: Optional[ArrayLike] = None, density: Optional[bool] = None) -> Tuple[Array, Array, Array]: - _check_arraylike("histogram2d", x, y) + util._check_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -491,19 +472,19 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLik hist, edges = histogramdd(sample, bins, range, weights, density) return hist, edges[0], edges[1] -@_wraps(np.histogramdd) +@util._wraps(np.histogramdd) def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None, weights: Optional[ArrayLike] = None, density: Optional[bool] = None) -> Tuple[Array, List[Array]]: if weights is None: - _check_arraylike("histogramdd", sample) - sample, = _promote_dtypes_inexact(sample) + util._check_arraylike("histogramdd", sample) + sample, 
= util._promote_dtypes_inexact(sample) else: - _check_arraylike("histogramdd", sample, weights) + util._check_arraylike("histogramdd", sample, weights) if shape(weights) != shape(sample)[:1]: raise ValueError("should have one weight for each sample.") - sample, weights = _promote_dtypes_inexact(sample, weights) + sample, weights = util._promote_dtypes_inexact(sample, weights) N, D = shape(sample) if range is not None and ( @@ -555,18 +536,18 @@ The JAX version of this function may in some cases return a copy rather than a view of the input. """ -@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC) def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array: - _stackable(a) or _check_arraylike("transpose", a) + util._stackable(a) or util._check_arraylike("transpose", a) axes_ = list(range(ndim(a))[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_] return lax.transpose(a, axes_) -@_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array: - _check_arraylike("rot90", m) + util._check_arraylike("rot90", m) ax1, ax2 = axes ax1 = _canonicalize_axis(ax1, ndim(m)) ax2 = _canonicalize_axis(ax2, ndim(m)) @@ -586,10 +567,10 @@ def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) -@_wraps(np.flip, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC) def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: - _check_arraylike("flip", m) - return _flip(asarray(m), _ensure_optional_axes(axis)) + util._check_arraylike("flip", m) + return _flip(asarray(m), reductions._ensure_optional_axes(axis)) @partial(jit, static_argnames=('axis',)) def _flip(m: Array, axis: 
Optional[Union[int, Tuple[int, ...]]] = None) -> Array: @@ -599,34 +580,34 @@ def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) -@_wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC) def fliplr(m: ArrayLike) -> Array: - _check_arraylike("fliplr", m) + util._check_arraylike("fliplr", m) return _flip(asarray(m), 1) -@_wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC) def flipud(m: ArrayLike) -> Array: - _check_arraylike("flipud", m) + util._check_arraylike("flipud", m) return _flip(asarray(m), 0) -@_wraps(np.iscomplex) +@util._wraps(np.iscomplex) @jit def iscomplex(x: ArrayLike) -> Array: - i = imag(x) + i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) -@_wraps(np.isreal) +@util._wraps(np.isreal) @jit def isreal(x: ArrayLike) -> Array: - i = imag(x) + i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) -@_wraps(np.angle) +@util._wraps(np.angle) @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: - re = real(z) - im = imag(z) + re = ufuncs.real(z) + im = ufuncs.imag(z) dtype = _dtype(re) if not issubdtype(dtype, inexact) or ( issubdtype(_dtype(z), floating) and ndim(z) == 0): @@ -634,15 +615,15 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: re = lax.convert_element_type(re, dtype) im = lax.convert_element_type(im, dtype) result = lax.atan2(im, re) - return degrees(result) if deg else result + return ufuncs.degrees(result) if deg else result -@_wraps(np.diff) +@util._wraps(np.diff) @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: Optional[ArrayLike] = None, append: Optional[ArrayLike] = None) -> Array: - _check_arraylike("diff", a) + util._check_arraylike("diff", a) arr = asarray(a) n = core.concrete_or_error(operator.index, n, "'n' argument of 
jnp.diff") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff") @@ -658,7 +639,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, combined: List[Array] = [] if prepend is not None: - _check_arraylike("diff", prepend) + util._check_arraylike("diff", prepend) if isscalar(prepend): shape = list(arr.shape) shape[axis] = 1 @@ -668,7 +649,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, combined.append(arr) if append is not None: - _check_arraylike("diff", append) + util._check_arraylike("diff", append) if isscalar(append): shape = list(arr.shape) shape[axis] = 1 @@ -685,7 +666,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, slice1_tuple = tuple(slice1) slice2_tuple = tuple(slice2) - op = not_equal if arr.dtype == np.bool_ else subtract + op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract for _ in range(n): arr = op(arr[slice1_tuple], arr[slice2_tuple]) @@ -697,30 +678,30 @@ issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary`` loses precision. 
""" -@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) +@util._wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None, to_begin: Optional[ArrayLike] = None) -> Array: - _check_arraylike("ediff1d", ary) + util._check_arraylike("ediff1d", ary) arr = ravel(ary) result = lax.sub(arr[1:], arr[:-1]) if to_begin is not None: - _check_arraylike("ediff1d", to_begin) + util._check_arraylike("ediff1d", to_begin) result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result)) if to_end is not None: - _check_arraylike("ediff1d", to_end) + util._check_arraylike("ediff1d", to_end) result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result -@_wraps(np.gradient, skip_params=['edge_order']) +@util._wraps(np.gradient, skip_params=['edge_order']) @partial(jit, static_argnames=('axis', 'edge_order')) def gradient(f: ArrayLike, *varargs: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, edge_order: Optional[int] = None) -> Union[Array, List[Array]]: if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") - a, *spacing = _promote_args_inexact("gradient", f, *varargs) + a, *spacing = util._promote_args_inexact("gradient", f, *varargs) def gradient_along_axis(a, h, axis): sliced = partial(lax.slice_in_dim, a, axis=axis) @@ -762,14 +743,14 @@ def gradient(f: ArrayLike, *varargs: ArrayLike, return a_grad[0] if len(axis_tuple) == 1 else a_grad -@_wraps(np.isrealobj) +@util._wraps(np.isrealobj) def isrealobj(x: Any) -> bool: return not iscomplexobj(x) -@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) def reshape(a: ArrayLike, newshape: Union[DimSize, Shape], order: str = "C") -> Array: - _stackable(a) or _check_arraylike("reshape", a) + util._stackable(a) or util._check_arraylike("reshape", a) try: # forward to method for ndarrays return a.reshape(newshape, 
order=order) # type: ignore[call-overload,union-attr] @@ -833,21 +814,21 @@ def _transpose(a: Array, *args: Any) -> Array: axis = _ensure_index_tuple(args) return transpose(a, axis) -@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: - _stackable(a) or _check_arraylike("ravel", a) + util._stackable(a) or util._check_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") return reshape(a, (size(a),), order) -@_wraps(np.ravel_multi_index) +@util._wraps(np.ravel_multi_index) def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...], mode: str = 'raise', order: str = 'C') -> Array: assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) - _check_arraylike("ravel_multi_index", *multi_index) + util._check_arraylike("ravel_multi_index", *multi_index) multi_index_arr = [asarray(i) for i in multi_index] for index in multi_index_arr: if mode == 'raise': @@ -857,7 +838,7 @@ def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...], if not issubdtype(_dtype(index), integer): raise TypeError("only int indices permitted") if mode == "raise": - if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index_arr, dims)): + if _any(reductions.any((i < 0) | (i >= d)) for i, d in zip(multi_index_arr, dims)): raise ValueError("invalid entry in coordinates array") elif mode == "clip": multi_index_arr = [clip(i, 0, d - 1) for i, d in zip(multi_index_arr, dims)] @@ -885,9 +866,9 @@ Unlike numpy's implementation of unravel_index, negative indices are accepted and out-of-bounds indices are clipped into the valid range. 
""" -@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) +@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]: - _check_arraylike("unravel_index", indices) + util._check_arraylike("unravel_index", indices) indices_arr = asarray(indices) # Note: we do not convert shape to an array, because it may be passed as a # tuple of weakly-typed values, and asarray() would strip these weak types. @@ -899,16 +880,16 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]: raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") out_indices = [0] * len(shape) for i, s in reversed(list(enumerate(shape))): - indices_arr, out_indices[i] = divmod(indices_arr, s) + indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s) oob_pos = indices_arr > 0 oob_neg = indices_arr < -1 return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in safe_zip(shape, out_indices)) -@_wraps(np.resize) +@util._wraps(np.resize) @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: - _check_arraylike("resize", a) + util._check_arraylike("resize", a) new_shape = _ensure_index_tuple(new_shape) if _any(dim_length < 0 for dim_length in new_shape): @@ -925,9 +906,9 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) -@_wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC) def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: - _check_arraylike("squeeze", a) + util._check_arraylike("squeeze", a) return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) @partial(jit, static_argnames=('axis',), inline=True) @@ -941,28 +922,28 @@ def _squeeze(a: Array, axis: Tuple[int]) -> Array: return lax.squeeze(a, axis) -@_wraps(np.expand_dims) +@util._wraps(np.expand_dims) def 
expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: - _stackable(a) or _check_arraylike("expand_dims", a) + util._stackable(a) or util._check_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) if hasattr(a, "expand_dims"): return a.expand_dims(axis) # type: ignore return lax.expand_dims(a, axis) -@_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: - _check_arraylike("swapaxes", a) + util._check_arraylike("swapaxes", a) perm = np.arange(ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] return lax.transpose(a, list(perm)) -@_wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], destination: Union[int, Sequence[int]]) -> Array: - _check_arraylike("moveaxis", a) + util._check_arraylike("moveaxis", a) return _moveaxis(asarray(a), _ensure_index_tuple(source), _ensure_index_tuple(destination)) @@ -979,57 +960,57 @@ def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) - return lax.transpose(a, perm) -@_wraps(np.isclose) +@util._wraps(np.isclose) @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: - a, b = _promote_args("isclose", a, b) + a, b = util._promote_args("isclose", a, b) dtype = _dtype(a) if issubdtype(dtype, inexact): if issubdtype(dtype, complexfloating): - dtype = _complex_elem_type(dtype) + dtype = util._complex_elem_type(dtype) rtol = lax.convert_element_type(rtol, dtype) atol = lax.convert_element_type(atol, dtype) out = lax.le( lax.abs(lax.sub(a, b)), lax.add(atol, lax.mul(rtol, lax.abs(b)))) # This corrects the comparisons for infinite and nan values - a_inf = isinf(a) - b_inf 
= isinf(b) - any_inf = logical_or(a_inf, b_inf) - both_inf = logical_and(a_inf, b_inf) + a_inf = ufuncs.isinf(a) + b_inf = ufuncs.isinf(b) + any_inf = ufuncs.logical_or(a_inf, b_inf) + both_inf = ufuncs.logical_and(a_inf, b_inf) # Make all elements where either a or b are infinite to False - out = logical_and(out, logical_not(any_inf)) + out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) # Make all elements where both a or b are the same inf to True same_value = lax.eq(a, b) - same_inf = logical_and(both_inf, same_value) - out = logical_or(out, same_inf) + same_inf = ufuncs.logical_and(both_inf, same_value) + out = ufuncs.logical_or(out, same_inf) # Make all elements where either a or b is NaN to False - a_nan = isnan(a) - b_nan = isnan(b) - any_nan = logical_or(a_nan, b_nan) - out = logical_and(out, logical_not(any_nan)) + a_nan = ufuncs.isnan(a) + b_nan = ufuncs.isnan(b) + any_nan = ufuncs.logical_or(a_nan, b_nan) + out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) if equal_nan: # Make all elements where both a and b is NaN to True - both_nan = logical_and(a_nan, b_nan) - out = logical_or(out, both_nan) + both_nan = ufuncs.logical_and(a_nan, b_nan) + out = ufuncs.logical_or(out, both_nan) return out else: return lax.eq(a, b) -@_wraps(np.interp) +@util._wraps(np.interp) @jit def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: Optional[ArrayLike] = None, right: Optional[ArrayLike] = None, period: Optional[ArrayLike] = None) -> Array: - _check_arraylike("interp", x, xp, fp) + util._check_arraylike("interp", x, xp, fp) if shape(xp) != shape(fp) or ndim(xp) != 1: raise ValueError("xp and fp must be one-dimensional arrays of equal size") - x_arr, xp_arr = _promote_dtypes_inexact(x, xp) - fp_arr, = _promote_dtypes_inexact(fp) + x_arr, xp_arr = util._promote_dtypes_inexact(x, xp) + fp_arr, = util._promote_dtypes_inexact(fp) del x, xp, fp if dtypes.issubdtype(x_arr.dtype, np.complexfloating): @@ -1038,7 +1019,7 @@ def interp(x: ArrayLike, xp: 
ArrayLike, fp: ArrayLike, if period is not None: if ndim(period) != 0: raise ValueError(f"period must be a scalar; got {period}") - period = abs(period) + period = ufuncs.abs(period) x_arr = x_arr % period xp_arr = xp_arr % period xp_arr, fp_arr = lax.sort_key_val(xp_arr, fp_arr) @@ -1081,7 +1062,7 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = None, fill_value: Union[None, Array, Tuple[ArrayLike]] = None ) -> Union[Array, Tuple[Array, ...]]: ... -@_wraps(np.where, +@util._wraps(np.where, lax_description=_dedent(""" At present, JAX does not support JIT-compilation of the single-argument form of :py:func:`jax.numpy.where` because its output shape is data-dependent. The @@ -1110,23 +1091,23 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = None, fill_value: Union[None, Array, Tuple[ArrayLike]] = None ) -> Union[Array, Tuple[Array, ...]]: if x is None and y is None: - _check_arraylike("where", condition) + util._check_arraylike("where", condition) return nonzero(condition, size=size, fill_value=fill_value) else: - _check_arraylike("where", condition, x, y) + util._check_arraylike("where", condition, x, y) if size is not None or fill_value is not None: raise ValueError("size and fill_value arguments cannot be used in three-term where function.") - return _where(condition, x, y) + return util._where(condition, x, y) -@_wraps(np.select) +@util._wraps(np.select) def select(condlist, choicelist, default=0): if len(condlist) != len(choicelist): msg = "condlist must have length equal to choicelist ({} vs {})" raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") - choices = _promote_dtypes(default, *choicelist) + choices = util._promote_dtypes(default, *choicelist) choicelist = choices[1:] output = choices[0] for cond, choice in zip(condlist[::-1], choicelist[::-1]): @@ -1134,7 +1115,7 @@ def select(condlist, choicelist, default=0): return output -@_wraps(np.bincount, 
lax_description="""\ +@util._wraps(np.bincount, lax_description="""\ Jax adds the optional `length` parameter which specifies the output length, and defaults to ``x.max() + 1``. It must be specified for bincount to be compiled with non-static operands. Values larger than the specified length will be discarded. @@ -1145,7 +1126,7 @@ negative values, ``jax.numpy.bincount`` clips negative values to zero. """) def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None, minlength: int = 0, *, length: Optional[int] = None) -> Array: - _check_arraylike("bincount", x) + util._check_arraylike("bincount", x) if not issubdtype(_dtype(x), integer): raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") if ndim(x) != 1: @@ -1173,7 +1154,7 @@ def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ... def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...] ) -> Tuple[Union[int, core.Tracer], ...]: ... -@_wraps(getattr(np, "broadcast_shapes", None)) +@util._wraps(getattr(np, "broadcast_shapes", None)) def broadcast_shapes(*shapes): if not shapes: return () @@ -1181,23 +1162,23 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) -@_wraps(np.broadcast_arrays, lax_description="""\ +@util._wraps(np.broadcast_arrays, lax_description="""\ The JAX version does not necessarily return a view of the input. """) def broadcast_arrays(*args: ArrayLike) -> List[Array]: - return _broadcast_arrays(*args) + return util._broadcast_arrays(*args) -@_wraps(np.broadcast_to, lax_description="""\ +@util._wraps(np.broadcast_to, lax_description="""\ The JAX version does not necessarily return a view of the input. 
""") def broadcast_to(array: ArrayLike, shape: Shape) -> Array: - return _broadcast_to(array, shape) + return util._broadcast_to(array, shape) def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]: - _check_arraylike(op, ary) + util._check_arraylike(op, ary) ary = asarray(ary) axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`") size = ary.shape[axis] @@ -1233,16 +1214,16 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) for start, end in zip(split_indices[:-1], split_indices[1:])] -@_wraps(np.split, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC) def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]: return _split("split", ary, indices_or_sections, axis=axis) def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]: - @_wraps(getattr(np, op), update_doc=False) + @util._wraps(getattr(np, op), update_doc=False) def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]: # for 1-D array, hsplit becomes vsplit nonlocal axis - _check_arraylike(op, ary) + util._check_arraylike(op, ary) a = asarray(ary) if axis == 1 and len(a.shape) == 1: axis = 0 @@ -1253,29 +1234,29 @@ vsplit = _split_on_axis("vsplit", axis=0) hsplit = _split_on_axis("hsplit", axis=1) dsplit = _split_on_axis("dsplit", axis=2) -@_wraps(np.array_split) +@util._wraps(np.array_split) def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) -@_wraps(np.clip, skip_params=['out']) +@util._wraps(np.clip, skip_params=['out']) @jit def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None, a_max: Optional[ArrayLike] = None, out: None = None) -> 
Array: - _check_arraylike("clip", a) + util._check_arraylike("clip", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.clip is not supported.") if a_min is None and a_max is None: raise ValueError("At most one of a_min and a_max may be None") if a_min is not None: - a = maximum(a_min, a) + a = ufuncs.maximum(a_min, a) if a_max is not None: - a = minimum(a_max, a) + a = ufuncs.minimum(a_max, a) return asarray(a) -@_wraps(np.around, skip_params=['out']) +@util._wraps(np.around, skip_params=['out']) @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: - _check_arraylike("round", a) + util._check_arraylike("round", a) decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: raise NotImplementedError("The 'out' argument to jnp.round is not supported.") @@ -1308,23 +1289,23 @@ around = round round_ = round -@_wraps(np.fix, skip_params=['out']) +@util._wraps(np.fix, skip_params=['out']) @jit def fix(x: ArrayLike, out: None = None) -> Array: - _check_arraylike("fix", x) + util._check_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") zero = _lax_const(x, 0) - return where(lax.ge(x, zero), floor(x), ceil(x)) + return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) -@_wraps(np.nan_to_num) +@util._wraps(np.nan_to_num) @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: Optional[ArrayLike] = None, neginf: Optional[ArrayLike] = None) -> Array: del copy - _check_arraylike("nan_to_num", x) + util._check_arraylike("nan_to_num", x) dtype = _dtype(x) if not issubdtype(dtype, inexact): return asarray(x) @@ -1335,18 +1316,18 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, info = finfo(dtypes.canonicalize_dtype(dtype)) posinf = info.max if posinf is None else posinf neginf = info.min if neginf is None else neginf 
- out = where(isnan(x), asarray(nan, dtype=dtype), x) - out = where(isposinf(out), asarray(posinf, dtype=dtype), out) - out = where(isneginf(out), asarray(neginf, dtype=dtype), out) + out = where(ufuncs.isnan(x), asarray(nan, dtype=dtype), x) + out = where(ufuncs.isposinf(out), asarray(posinf, dtype=dtype), out) + out = where(ufuncs.isneginf(out), asarray(neginf, dtype=dtype), out) return out -@_wraps(np.allclose) +@util._wraps(np.allclose) @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: - _check_arraylike("allclose", a, b) - return all(isclose(a, b, rtol, atol, equal_nan)) + util._check_arraylike("allclose", a, b) + return reductions.all(isclose(a, b, rtol, atol, equal_nan)) _NONZERO_DOC = """\ @@ -1364,11 +1345,11 @@ fill_value : array_like, optional remaining elements will be filled with ``fill_value``, which defaults to zero. """ -@_wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) +@util._wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def nonzero(a: ArrayLike, *, size: Optional[int] = None, fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None ) -> Tuple[Array, ...]: - _check_arraylike("nonzero", a) + util._check_arraylike("nonzero", a) arr = atleast_1d(a) del a mask = arr if arr.dtype == bool else (arr != 0) @@ -1380,7 +1361,7 @@ def nonzero(a: ArrayLike, *, size: Optional[int] = None, "to use jnp.nonzero within JAX transformations.") if arr.size == 0 or size == 0: return tuple(zeros(size, int) for dim in arr.shape) - flat_indices = cumsum(bincount(cumsum(mask), length=size)) + flat_indices = reductions.cumsum(bincount(reductions.cumsum(mask), length=size)) strides = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(int_) out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape)) if size is not None and fill_value is not None: 
@@ -1391,34 +1372,34 @@ def nonzero(a: ArrayLike, *, size: Optional[int] = None, out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out)) return out -@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) +@util._wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def flatnonzero(a: ArrayLike, *, size: Optional[int] = None, fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array: return nonzero(ravel(a), size=size, fill_value=fill_value)[0] -@_wraps(np.unwrap) +@util._wraps(np.unwrap) @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: - _check_arraylike("unwrap", p) + util._check_arraylike("unwrap", p) p = asarray(p) if issubdtype(p.dtype, np.complexfloating): raise ValueError("jnp.unwrap does not support complex inputs.") if p.shape[axis] == 0: - return _promote_dtypes_inexact(p)[0] + return util._promote_dtypes_inexact(p)[0] if discont is None: discont = period / 2 interval = period / 2 dd = diff(p, axis=axis) - ddmod = mod(dd + interval, period) - interval + ddmod = ufuncs.mod(dd + interval, period) - interval ddmod = where((ddmod == -interval) & (dd > 0), interval, ddmod) - ph_correct = where(abs(dd) < discont, 0, ddmod - dd) + ph_correct = where(ufuncs.abs(dd) < discont, 0, ddmod - dd) up = concatenate(( lax.slice_in_dim(p, 0, 1, axis=axis), - lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(ph_correct, axis=axis) + lax.slice_in_dim(p, 1, None, axis=axis) + reductions.cumsum(ph_correct, axis=axis) ), axis=axis) return up @@ -1687,7 +1668,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], return array stat_funcs: Dict[str, PadStatFunc] = { - "maximum": amax, "minimum": amin, "mean": mean, "median": median} + "maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": median} pad_width = 
_broadcast_to_pairs(pad_width, nd, "pad_width") pad_width_arr = np.array(pad_width) @@ -1726,14 +1707,14 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], "not implemented modes") -@_wraps(np.pad, lax_description="""\ +@util._wraps(np.pad, lax_description="""\ Unlike numpy, JAX "function" mode's argument (which is another function) should return the modified array. This is because Jax arrays are immutable. (In numpy, "function" mode's argument should modify a rank 1 array in-place.) """) def pad(array: ArrayLike, pad_width: PadValueLike[int], mode: Union[str, Callable[..., Any]] = "constant", **kwargs) -> Array: - _check_arraylike("pad", array) + util._check_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") if pad_width and not _all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -1772,7 +1753,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int], ### Array-creation functions -@_wraps(np.stack, skip_params=['out']) +@util._wraps(np.stack, skip_params=['out']) def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], axis: int = 0, out: None = None, dtype: Optional[DTypeLike] = None) -> Array: if not len(arrays): @@ -1783,7 +1764,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - _stackable(*arrays) or _check_arraylike("stack", *arrays) + util._stackable(*arrays) or util._check_arraylike("stack", *arrays) shape0 = shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -1793,9 +1774,9 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) -@_wraps(np.tile) +@util._wraps(np.tile) def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: - _stackable(A) or _check_arraylike("tile", A) + 
util._stackable(A) or util._check_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -1825,12 +1806,12 @@ def _concatenate_array(arr: ArrayLike, axis: Optional[int], dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)] return lax.reshape(arr, shape, dimensions) -@_wraps(np.concatenate) +@util._wraps(np.concatenate) def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array: if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - _stackable(*arrays) or _check_arraylike("concatenate", *arrays) + util._stackable(*arrays) or util._check_arraylike("concatenate", *arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if ndim(arrays[0]) == 0: @@ -1841,7 +1822,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr] axis = _canonicalize_axis(axis, ndim(arrays[0])) if dtype is None: - arrays_out = _promote_dtypes(*arrays) + arrays_out = util._promote_dtypes(*arrays) else: arrays_out = [asarray(arr, dtype=dtype) for arr in arrays] # lax.concatenate can be slow to compile for wide concatenations, so form a @@ -1854,7 +1835,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], return arrays_out[0] -@_wraps(np.vstack) +@util._wraps(np.vstack) def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: if isinstance(tup, (np.ndarray, Array)): @@ -1865,7 +1846,7 @@ def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], row_stack = vstack -@_wraps(np.hstack) +@util._wraps(np.hstack) def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: if isinstance(tup, (np.ndarray, Array)): @@ -1877,7 +1858,7 @@ def hstack(tup: Union[np.ndarray, Array, 
Sequence[ArrayLike]], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) -@_wraps(np.dstack) +@util._wraps(np.dstack) def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: if isinstance(tup, (np.ndarray, Array)): @@ -1887,7 +1868,7 @@ def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], return concatenate(arrs, axis=2, dtype=dtype) -@_wraps(np.column_stack) +@util._wraps(np.column_stack) def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array: if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup @@ -1896,12 +1877,12 @@ def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array: return concatenate(arrs, 1) -@_wraps(np.choose, skip_params=['out']) +@util._wraps(np.choose, skip_params=['out']) def choose(a: ArrayLike, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - _check_arraylike('choose', a, *choices) + util._check_arraylike('choose', a, *choices) if not issubdtype(_dtype(a), integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -1910,7 +1891,7 @@ def choose(a: ArrayLike, choices: Sequence[ArrayLike], arr: Array = core.concrete_or_error(asarray, a, "The error occurred because jnp.choose was jit-compiled" " with mode='raise'. 
Use mode='wrap' or mode='clip' instead.") - if any((arr < 0) | (arr >= N)): + if reductions.any((arr < 0) | (arr >= N)): raise ValueError("invalid entry in choice array") elif mode == 'wrap': arr = asarray(a) % N @@ -1943,13 +1924,13 @@ def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]: else: return asarray(xs), 1 -@_wraps(np.block) +@util._wraps(np.block) @jit def block(arrays: Union[ArrayLike, List[ArrayLike]]) -> Array: out, _ = _block(arrays) return out -@_wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]: if len(arys) == 1: @@ -1959,7 +1940,7 @@ def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]: return [atleast_1d(arr) for arr in arys] -@_wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]: if len(arys) == 1: @@ -1974,7 +1955,7 @@ def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]: return [atleast_2d(arr) for arr in arys] -@_wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_3d(*arys: ArrayLike) -> Union[Array, List[Array]]: if len(arys) == 1: @@ -1997,7 +1978,7 @@ available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at https://jax.readthedocs.io/en/latest/faq.html). 
""" -@_wraps(np.array, lax_description=_ARRAY_DOC) +@util._wraps(np.array, lax_description=_ARRAY_DOC) def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, order: Optional[str] = "K", ndmin: int = 0) -> Array: if order is not None and order != "K": @@ -2082,54 +2063,54 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x -@_wraps(np.asarray, lax_description=_ARRAY_DOC) +@util._wraps(np.asarray, lax_description=_ARRAY_DOC) def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "asarray") dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype return array(a, dtype=dtype, copy=False, order=order) # type: ignore -@_wraps(np.copy, lax_description=_ARRAY_DOC) +@util._wraps(np.copy, lax_description=_ARRAY_DOC) def copy(a: ArrayLike, order: Optional[str] = None) -> Array: - _check_arraylike("copy", a) + util._check_arraylike("copy", a) return array(a, copy=True, order=order) -@_wraps(np.zeros_like) +@util._wraps(np.zeros_like) def zeros_like(a: ArrayLike, dtype: Optional[DTypeLike] = None, shape: Any = None) -> Array: - _check_arraylike("zeros_like", a) + util._check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") if shape is not None: shape = canonicalize_shape(shape) return lax.full_like(a, 0, dtype, shape) -@_wraps(np.ones_like) +@util._wraps(np.ones_like) def ones_like(a: ArrayLike, dtype: Optional[DTypeLike] = None, shape: Any = None) -> Array: - _check_arraylike("ones_like", a) + util._check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") if shape is not None: shape = canonicalize_shape(shape) return lax.full_like(a, 1, dtype, shape) -@_wraps(np.empty_like, lax_description="""\ +@util._wraps(np.empty_like, lax_description="""\ Because XLA cannot create uninitialized arrays, the JAX version will return an array initialized with zeros.""") def 
empty_like(prototype: ArrayLike, dtype: Optional[DTypeLike] = None, shape: Any = None) -> Array: - _check_arraylike("empty_like", prototype) + util._check_arraylike("empty_like", prototype) dtypes.check_user_dtype_supported(dtype, "empty_like") return zeros_like(prototype, dtype=dtype, shape=shape) -@_wraps(np.full) +@util._wraps(np.full) def full(shape: Any, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "full") - _check_arraylike("full", fill_value) + util._check_arraylike("full", fill_value) if ndim(fill_value) == 0: shape = canonicalize_shape(shape) return lax.full(shape, fill_value, dtype) @@ -2137,11 +2118,11 @@ def full(shape: Any, fill_value: ArrayLike, return broadcast_to(asarray(fill_value, dtype=dtype), shape) -@_wraps(np.full_like) +@util._wraps(np.full_like) def full_like(a: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None, shape: Any = None) -> Array: dtypes.check_user_dtype_supported(dtype, "full_like") - _check_arraylike("full_like", a, fill_value) + util._check_arraylike("full_like", a, fill_value) if shape is not None: shape = canonicalize_shape(shape) if ndim(fill_value) == 0: @@ -2152,7 +2133,7 @@ def full_like(a: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = return broadcast_to(asarray(fill_value, dtype=dtype), shape) -@_wraps(np.zeros) +@util._wraps(np.zeros) def zeros(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") @@ -2160,7 +2141,7 @@ def zeros(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: shape = canonicalize_shape(shape) return lax.full(shape, 0, _jnp_dtype(dtype)) -@_wraps(np.ones) +@util._wraps(np.ones) def ones(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") @@ -2169,7 
+2150,7 @@ def ones(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: return lax.full(shape, 1, _jnp_dtype(dtype)) -@_wraps(np.empty, lax_description="""\ +@util._wraps(np.empty, lax_description="""\ Because XLA cannot create uninitialized arrays, the JAX version will return an array initialized with zeros.""") def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: @@ -2177,7 +2158,7 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: return zeros(shape, dtype) -@_wraps(np.array_equal) +@util._wraps(np.array_equal) def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: try: a1, a2 = asarray(a1), asarray(a2) @@ -2187,27 +2168,27 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return bool_(False) eq = asarray(a1 == a2) if equal_nan: - eq = logical_or(eq, logical_and(isnan(a1), isnan(a2))) - return all(eq) + eq = ufuncs.logical_or(eq, ufuncs.logical_and(ufuncs.isnan(a1), ufuncs.isnan(a2))) + return reductions.all(eq) -@_wraps(np.array_equiv) +@util._wraps(np.array_equiv) def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: try: a1, a2 = asarray(a1), asarray(a2) except Exception: return bool_(False) try: - eq = equal(a1, a2) + eq = ufuncs.equal(a1, a2) except ValueError: # shapes are not broadcastable return bool_(False) - return all(eq) + return reductions.all(eq) # General np.from* style functions mostly delegate to numpy. 
-@_wraps(np.frombuffer) +@util._wraps(np.frombuffer) def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) @@ -2248,12 +2229,12 @@ def fromiter(*args, **kwargs): "because of its potential side-effect of consuming the iterable object; for more information see " "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") -@_wraps(getattr(np, "from_dlpack", None)) +@util._wraps(getattr(np, "from_dlpack", None)) def from_dlpack(x: Any) -> Array: from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top return from_dlpack(x.__dlpack__()) -@_wraps(np.fromfunction) +@util._wraps(np.fromfunction) def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") @@ -2263,12 +2244,12 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) -@_wraps(np.fromstring) +@util._wraps(np.fromstring) def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) -@_wraps(np.eye) +@util._wraps(np.eye) def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0, dtype: Optional[DTypeLike] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "eye") @@ -2280,13 +2261,13 @@ def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0, return lax_internal._eye(_jnp_dtype(dtype), (N_int, M_int), k) -@_wraps(np.identity) +@util._wraps(np.identity) def identity(n: DimSize, dtype: Optional[DTypeLike] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "identity") return eye(n, dtype=dtype) -@_wraps(np.arange) +@util._wraps(np.arange) def arange(start: DimSize, stop: Optional[DimSize] 
= None, step: Optional[DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") @@ -2308,7 +2289,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None, start = require(start, msg("stop")) if (not dtypes.issubdtype(start_dtype, np.integer) and not core.is_opaque_dtype(start_dtype)): - ceil_ = ceil if isinstance(start, core.Tracer) else np.ceil + ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil start = ceil_(start).astype(int) # type: ignore return lax.iota(dtype, start) else: @@ -2341,7 +2322,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: Optional[DTypeLike] = None, axis: int = 0) -> Union[Array, Tuple[Array, Array]]: ... -@_wraps(np.linspace) +@util._wraps(np.linspace) def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: Optional[DTypeLike] = None, @@ -2359,7 +2340,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtypes.check_user_dtype_supported(dtype, "linspace") if num < 0: raise ValueError(f"Number of samples, {num}, must be non-negative.") - _check_arraylike("linspace", start, stop) + util._check_arraylike("linspace", start, stop) if dtype is None: dtype = dtypes.to_inexact_dtype(result_type(start, stop)) @@ -2408,7 +2389,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(out, dtype) -@_wraps(np.logspace) +@util._wraps(np.logspace) def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array: @@ -2426,15 +2407,15 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) - _check_arraylike("logspace", start, stop) + 
util._check_arraylike("logspace", start, stop) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) lin = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis) - return lax.convert_element_type(power(base, lin), dtype) + return lax.convert_element_type(ufuncs.power(base, lin), dtype) -@_wraps(np.geomspace) +@util._wraps(np.geomspace) def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array: num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") @@ -2450,14 +2431,14 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) - _check_arraylike("geomspace", start, stop) + util._check_arraylike("geomspace", start, stop) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) # follow the numpy geomspace convention for negative and complex endpoints - signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2 + signflip = 1 - (1 - ufuncs.sign(ufuncs.real(start))) * (1 - ufuncs.sign(ufuncs.real(stop))) // 2 signflip = signflip.astype(computation_dtype) - res = signflip * logspace(log10(signflip * start), - log10(signflip * stop), num, + res = signflip * logspace(ufuncs.log10(signflip * start), + ufuncs.log10(signflip * stop), num, endpoint=endpoint, base=10.0, dtype=computation_dtype, axis=0) if axis != 0: @@ -2465,10 +2446,10 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) -@_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> List[Array]: - 
_check_arraylike("meshgrid", *xi) + util._check_arraylike("meshgrid", *xi) args = [asarray(x) for x in xi] if not copy: raise ValueError("jax.numpy.meshgrid only supports copy=True") @@ -2487,19 +2468,19 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output -@_wraps(np.i0) +@util._wraps(np.i0) @jit def i0(x: ArrayLike) -> Array: - x_arr, = _promote_args_inexact("i0", x) + x_arr, = util._promote_args_inexact("i0", x) if not issubdtype(x_arr.dtype, np.floating): raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") x_arr = lax.abs(x_arr) return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr)) -@_wraps(np.ix_) +@util._wraps(np.ix_) def ix_(*args: ArrayLike) -> Tuple[Array, ...]: - _check_arraylike("ix", *args) + util._check_arraylike("ix", *args) n = len(args) output = [] for i, a in enumerate(args): @@ -2529,7 +2510,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ... -@_wraps(np.indices) +@util._wraps(np.indices) def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: dimensions = tuple( @@ -2558,11 +2539,11 @@ will be repeated. """ -@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC) +@util._wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC) def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *, total_repeat_length: Optional[int] = None) -> Array: - _check_arraylike("repeat", a) - core.is_special_dim_size(repeats) or _check_arraylike("repeat", repeats) + util._check_arraylike("repeat", a) + core.is_special_dim_size(repeats) or util._check_arraylike("repeat", repeats) if axis is None: a = ravel(a) @@ -2630,16 +2611,16 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *, # Modify repeats from e.g. 
[1,2,0,5] -> [0,1,2,0] for exclusive repeat. exclusive_repeats = roll(repeats, shift=1).at[0].set(0) # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3] - scatter_indices = cumsum(exclusive_repeats) + scatter_indices = reductions.cumsum(exclusive_repeats) # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0] block_split_indicators = zeros([total_repeat_length], dtype=int32) block_split_indicators = block_split_indicators.at[scatter_indices].add(1) # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3] - gather_indices = cumsum(block_split_indicators) - 1 + gather_indices = reductions.cumsum(block_split_indicators) - 1 return take(a, gather_indices, axis=axis) -@_wraps(np.tri) +@util._wraps(np.tri) def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) -> Array: dtypes.check_user_dtype_supported(dtype, "tri") M = M if M is not None else N @@ -2647,10 +2628,10 @@ def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) -> return lax_internal._tri(dtype, (N, M), k) -@_wraps(np.tril) +@util._wraps(np.tril) @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: - _check_arraylike("tril", m) + util._check_arraylike("tril", m) m_shape = shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") @@ -2659,10 +2640,10 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) -@_wraps(np.triu, update_doc=False) +@util._wraps(np.triu, update_doc=False) @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: - _check_arraylike("triu", m) + util._check_arraylike("triu", m) m_shape = shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") @@ -2671,11 +2652,11 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) -@_wraps(np.trace, 
skip_params=['out']) +@util._wraps(np.trace, skip_params=['out']) @partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[DTypeLike] = None, out: None = None) -> Array: - _check_arraylike("trace", a) + util._check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") dtypes.check_user_dtype_supported(dtype, "trace") @@ -2693,11 +2674,11 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, # Mask out the diagonal and reduce. a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool), a, zeros_like(a)) - return sum(a, axis=(-2, -1), dtype=dtype) + return reductions.sum(a, axis=(-2, -1), dtype=dtype) def _wrap_indices_function(f): - @_wraps(f, update_doc=False) + @util._wraps(f, update_doc=False) def wrapper(*args, **kwargs): args = [core.concrete_or_error( None, arg, f"argument {i} of jnp.{f.__name__}()") @@ -2721,7 +2702,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) -@_wraps(np.triu_indices) +@util._wraps(np.triu_indices) def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]: n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") @@ -2730,7 +2711,7 @@ def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar return i, j -@_wraps(np.tril_indices) +@util._wraps(np.tril_indices) def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]: n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") @@ -2739,19 +2720,19 @@ def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar return i, j -@_wraps(np.triu_indices_from) 
+@util._wraps(np.triu_indices_from) def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]: arr_shape = shape(arr) return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1]) -@_wraps(np.tril_indices_from) +@util._wraps(np.tril_indices_from) def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]: arr_shape = shape(arr) return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1]) -@_wraps(np.diag_indices) +@util._wraps(np.diag_indices) def diag_indices(n, ndim=2): n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()") ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()") @@ -2763,9 +2744,9 @@ def diag_indices(n, ndim=2): .format(ndim)) return (lax.iota(int_, n),) * ndim -@_wraps(np.diag_indices_from) +@util._wraps(np.diag_indices_from) def diag_indices_from(arr): - _check_arraylike("diag_indices_from", arr) + util._check_arraylike("diag_indices_from", arr) if not arr.ndim >= 2: raise ValueError("input array must be at least 2-d") @@ -2774,10 +2755,10 @@ def diag_indices_from(arr): return diag_indices(arr.shape[0], ndim=arr.ndim) -@_wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1): - _check_arraylike("diagonal", a) + util._check_arraylike("diagonal", a) a_shape = shape(a) if ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") @@ -2792,13 +2773,13 @@ def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1): return a[..., i, j] if offset >= 0 else a[..., j, i] -@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.diag, lax_description=_ARRAY_VIEW_DOC) def diag(v, k=0): return _diag(v, operator.index(k)) @partial(jit, static_argnames=('k',)) def _diag(v, k): - _check_arraylike("diag", v) + util._check_arraylike("diag", v) v_shape = shape(v) if 
len(v_shape) == 1: zero = lambda x: lax.full_like(x, shape=(), fill_value=0) @@ -2816,9 +2797,9 @@ jax always returns a two-dimensional array, whereas numpy may return a scalar depending on the type of v. """ -@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC) +@util._wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC) def diagflat(v, k=0): - _check_arraylike("diagflat", v) + util._check_arraylike("diagflat", v) v = ravel(v) v_length = len(v) adj_length = v_length + _abs(k) @@ -2833,12 +2814,12 @@ def diagflat(v, k=0): return res -@_wraps(np.trim_zeros) +@util._wraps(np.trim_zeros) def trim_zeros(filt, trim='fb'): filt = core.concrete_or_error(asarray, filt, "Error arose in the `filt` argument of trim_zeros()") nz = (filt == 0) - if all(nz): + if reductions.all(nz): return empty(0, _dtype(filt)) start = argmin(nz) if 'f' in trim.lower() else 0 end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 @@ -2848,15 +2829,15 @@ def trim_zeros(filt, trim='fb'): def trim_zeros_tol(filt, tol, trim='fb'): filt = core.concrete_or_error(asarray, filt, "Error arose in the `filt` argument of trim_zeros_tol()") - nz = (abs(filt) < tol) - if all(nz): + nz = (ufuncs.abs(filt) < tol) + if reductions.all(nz): return empty(0, _dtype(filt)) start = argmin(nz) if 'f' in trim.lower() else 0 end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] -@_wraps(np.append) +@util._wraps(np.append) @partial(jit, static_argnames=('axis',)) def append(arr, values, axis: Optional[int] = None): if axis is None: @@ -2865,9 +2846,9 @@ def append(arr, values, axis: Optional[int] = None): return concatenate([arr, values], axis=axis) -@_wraps(np.delete) +@util._wraps(np.delete) def delete(arr, obj, axis=None): - _check_arraylike("delete", arr) + util._check_arraylike("delete", arr) if axis is None: arr = ravel(arr) axis = 0 @@ -2891,7 +2872,7 @@ def delete(arr, obj, axis=None): # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. 
- _check_arraylike("delete", arr, obj) + util._check_arraylike("delete", arr, obj) obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()") if issubdtype(obj.dtype, integer): @@ -2908,9 +2889,9 @@ def delete(arr, obj, axis=None): raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.") return arr[tuple(slice(None) for i in range(axis)) + (mask,)] -@_wraps(np.insert) +@util._wraps(np.insert) def insert(arr, obj, values, axis=None): - _check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values) + util._check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values) arr = asarray(arr) values = asarray(values) @@ -2960,7 +2941,7 @@ def insert(arr, obj, values, axis=None): return out -@_wraps(np.apply_along_axis) +@util._wraps(np.apply_along_axis) def apply_along_axis(func1d, axis: int, arr, *args, **kwargs): num_dims = ndim(arr) axis = _canonicalize_axis(axis, num_dims) @@ -2972,7 +2953,7 @@ def apply_along_axis(func1d, axis: int, arr, *args, **kwargs): return func(arr) -@_wraps(np.apply_over_axes) +@util._wraps(np.apply_over_axes) def apply_over_axes(func, a, axes): for axis in axes: b = func(a, axis=axis) @@ -2988,11 +2969,11 @@ def apply_over_axes(func, a, axes): ### Tensor contraction operations -@_wraps(np.dot, lax_description=_PRECISION_DOC) +@util._wraps(np.dot, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('precision',), inline=True) def dot(a, b, *, precision=None): # pylint: disable=missing-docstring - _check_arraylike("dot", a, b) - a, b = _promote_dtypes(a, b) + util._check_arraylike("dot", a, b) + a, b = util._promote_dtypes(a, b) a_ndim, b_ndim = ndim(a), ndim(b) if a_ndim == 0 or b_ndim == 0: return lax.mul(a, b) @@ -3007,17 +2988,17 @@ def dot(a, b, *, precision=None): # pylint: disable=missing-docstring return lax.dot_general(a, b, (contract_dims, batch_dims), precision) -@_wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC) 
+@util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('precision',), inline=True) def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring - _check_arraylike("matmul", a, b) + util._check_arraylike("matmul", a, b) for i, x in enumerate((a, b)): if ndim(x) < 1: msg = (f"matmul input operand {i} must have ndim at least 1, " f"but it has ndim {ndim(x)}") raise ValueError(msg) - a, b = _promote_dtypes(a, b) + a, b = util._promote_dtypes(a, b) a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1) a_batch_dims = shape(a)[:-2] if a_is_mat else () @@ -3070,22 +3051,22 @@ def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring return lax.transpose(out, perm) -@_wraps(np.vdot, lax_description=_PRECISION_DOC) +@util._wraps(np.vdot, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('precision',), inline=True) def vdot(a, b, *, precision=None): - _check_arraylike("vdot", a, b) + util._check_arraylike("vdot", a, b) if issubdtype(_dtype(a), complexfloating): - a = conj(a) + a = ufuncs.conj(a) return dot(a.ravel(), b.ravel(), precision=precision) -@_wraps(np.tensordot, lax_description=_PRECISION_DOC) +@util._wraps(np.tensordot, lax_description=_PRECISION_DOC) def tensordot(a, b, axes=2, *, precision=None): - _check_arraylike("tensordot", a, b) + util._check_arraylike("tensordot", a, b) a_ndim = ndim(a) b_ndim = ndim(b) - a, b = _promote_dtypes(a, b) + a, b = util._promote_dtypes(a, b) if type(axes) is int: if axes > _min(a_ndim, b_ndim): msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})" @@ -3120,7 +3101,7 @@ rather, the specified ``precision`` is forwarded to each ``dot_general`` call us the implementation. 
""" -@_wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out']) +@util._wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out']) def einsum(*operands, out=None, optimize='optimal', precision=None, _use_xeinsum=False): if out is not None: @@ -3166,7 +3147,7 @@ def _default_poly_einsum_handler(*operands, **kwargs): contract_operands = [operands[mapping[id(d)]] for d in out_dummies] return contract_operands, contractions -@_wraps(np.einsum_path) +@util._wraps(np.einsum_path) def einsum_path(subscripts, *operands, optimize='greedy'): # using einsum_call=True here is an internal api for opt_einsum return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) @@ -3178,7 +3159,7 @@ def _removechars(s, chars): def _einsum(operands: Sequence, contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]], precision): - operands = list(_promote_dtypes(*operands)) + operands = list(util._promote_dtypes(*operands)) def sum(x, axes): return lax.reduce(x, np.array(0, x.dtype), lax.add if x.dtype != bool_ else lax.bitwise_or, axes) @@ -3310,7 +3291,7 @@ def _einsum(operands: Sequence, return operands[0] -@_wraps(np.inner, lax_description=_PRECISION_DOC) +@util._wraps(np.inner, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('precision',), inline=True) def inner(a, b, *, precision=None): if ndim(a) == 0 or ndim(b) == 0: @@ -3318,15 +3299,15 @@ def inner(a, b, *, precision=None): return tensordot(a, b, (-1, -1), precision=precision) -@_wraps(np.outer, skip_params=['out']) +@util._wraps(np.outer, skip_params=['out']) @partial(jit, inline=True) def outer(a, b, out=None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") - a, b = _promote_dtypes(a, b) + a, b = util._promote_dtypes(a, b) return ravel(a)[:, None] * ravel(b)[None, :] -@_wraps(np.cross) +@util._wraps(np.cross) @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = 
-1, axisc: int = -1, axis: Optional[int] = None): @@ -3353,10 +3334,10 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) -@_wraps(np.kron) +@util._wraps(np.kron) @jit def kron(a, b): - a, b = _promote_dtypes(a, b) + a, b = util._promote_dtypes(a, b) if ndim(a) < ndim(b): a = expand_dims(a, range(ndim(b) - ndim(a))) elif ndim(b) < ndim(a): @@ -3367,10 +3348,10 @@ def kron(a, b): return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) -@_wraps(np.vander) +@util._wraps(np.vander) @partial(jit, static_argnames=('N', 'increasing')) def vander(x, N=None, increasing=False): - _check_arraylike("vander", x) + util._check_arraylike("vander", x) x = asarray(x) if x.ndim != 1: raise ValueError("x must be a one-dimensional array") @@ -3383,7 +3364,7 @@ def vander(x, N=None, increasing=False): if not increasing: iota = lax.sub(_lax_const(iota, N - 1), iota) - return power(x[..., None], expand_dims(iota, tuple(range(x.ndim)))) + return ufuncs.power(x[..., None], expand_dims(iota, tuple(range(x.ndim)))) ### Misc @@ -3397,7 +3378,7 @@ the indices of the first ``size`` True elements will be returned; if there are f nonzero elements than `size` indicates, the index arrays will be zero-padded. """ -@_wraps(np.argwhere, +@util._wraps(np.argwhere, lax_description=_dedent(""" Because the size of the output of ``argwhere`` is data-dependent, the function is not typically compatible with JIT. 
The JAX version adds the optional ``size`` argument which @@ -3417,13 +3398,13 @@ def argwhere(a, *, size=None, fill_value=None): return result.reshape(result.shape[0], ndim(a)) -@_wraps(np.argmax, skip_params=['out']) +@util._wraps(np.argmax, skip_params=['out']) def argmax(a, axis: Optional[int] = None, out=None, keepdims=None): return _argmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) def _argmax(a, axis: Optional[int] = None, out=None, keepdims=False): - _check_arraylike("argmax", a) + util._check_arraylike("argmax", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.") if axis is None: @@ -3437,13 +3418,13 @@ def _argmax(a, axis: Optional[int] = None, out=None, keepdims=False): result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_)) return expand_dims(result, dims) if keepdims else result -@_wraps(np.argmin, skip_params=['out']) +@util._wraps(np.argmin, skip_params=['out']) def argmin(a, axis: Optional[int] = None, out=None, keepdims=None): return _argmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) def _argmin(a, axis: Optional[int] = None, out=None, keepdims=False): - _check_arraylike("argmin", a) + util._check_arraylike("argmin", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.") if axis is None: @@ -3463,7 +3444,7 @@ Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise an error. 
""" -@_wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out']) +@util._wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out']) def nanargmax(a, axis: Optional[int] = None, out : Any = None, keepdims : Optional[bool] = None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.") @@ -3471,15 +3452,15 @@ def nanargmax(a, axis: Optional[int] = None, out : Any = None, keepdims : Option @partial(jit, static_argnames=('axis', 'keepdims')) def _nanargmax(a, axis: Optional[int] = None, keepdims: bool = False): - _check_arraylike("nanargmax", a) + util._check_arraylike("nanargmax", a) if not issubdtype(_dtype(a), inexact): return argmax(a, axis=axis, keepdims=keepdims) - nan_mask = isnan(a) + nan_mask = ufuncs.isnan(a) a = where(nan_mask, -inf, a) res = argmax(a, axis=axis, keepdims=keepdims) - return where(all(nan_mask, axis=axis, keepdims=keepdims), -1, res) + return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out']) +@util._wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out']) def nanargmin(a, axis: Optional[int] = None, out : Any = None, keepdims : Optional[bool] = None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") @@ -3487,19 +3468,19 @@ def nanargmin(a, axis: Optional[int] = None, out : Any = None, keepdims : Option @partial(jit, static_argnames=('axis', 'keepdims')) def _nanargmin(a, axis: Optional[int] = None, keepdims : bool = False): - _check_arraylike("nanargmin", a) + util._check_arraylike("nanargmin", a) if not issubdtype(_dtype(a), inexact): return argmin(a, axis=axis, keepdims=keepdims) - nan_mask = isnan(a) + nan_mask = ufuncs.isnan(a) a = where(nan_mask, inf, a) res = argmin(a, axis=axis, keepdims=keepdims) - return where(all(nan_mask, axis=axis, 
keepdims=keepdims), -1, res) + return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@_wraps(np.sort) +@util._wraps(np.sort) @partial(jit, static_argnames=('axis', 'kind', 'order')) def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None): - _check_arraylike("sort", a) + util._check_arraylike("sort", a) if kind != 'quicksort': warnings.warn("'kind' argument to sort is ignored.") if order is not None: @@ -3510,14 +3491,14 @@ def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None): else: return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a))) -@_wraps(np.sort_complex) +@util._wraps(np.sort_complex) @jit def sort_complex(a): - _check_arraylike("sort_complex", a) + util._check_arraylike("sort_complex", a) a = lax.sort(a, dimension=0) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) -@_wraps(np.lexsort) +@util._wraps(np.lexsort) @partial(jit, static_argnames=('axis',)) def lexsort(keys, axis=-1): keys = tuple(keys) @@ -3538,10 +3519,10 @@ Only :code:`kind='stable'` is supported. Other :code:`kind` values will produce a warning and be treated as if they were :code:`'stable'`. 
""" -@_wraps(np.argsort, lax_description=_ARGSORT_DOC) +@util._wraps(np.argsort, lax_description=_ARGSORT_DOC) @partial(jit, static_argnames=('axis', 'kind', 'order')) def argsort(a, axis: Optional[int] = -1, kind='stable', order=None): - _check_arraylike("argsort", a) + util._check_arraylike("argsort", a) if kind != 'stable': warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts " "are supported.") @@ -3558,14 +3539,14 @@ def argsort(a, axis: Optional[int] = -1, kind='stable', order=None): return perm -@_wraps(np.msort) +@util._wraps(np.msort) def msort(a): # TODO(jakevdp): remove msort after Feb 2023 warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning) return sort(a, axis=0) -@_wraps(np.partition, lax_description=""" +@util._wraps(np.partition, lax_description=""" The JAX version requires the ``kth`` argument to be a static integer rather than a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If you're only accessing the top or bottom k values of the output, it may be more @@ -3577,7 +3558,7 @@ NaNs which have the negative bit set are sorted to the beginning of the array. @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: # TODO(jakevdp): handle NaN values like numpy. - _check_arraylike("partition", a) + util._check_arraylike("partition", a) arr = asarray(a) if issubdtype(arr.dtype, np.complexfloating): raise NotImplementedError("jnp.partition for complex dtype is not implemented.") @@ -3591,7 +3572,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) -@_wraps(np.argpartition, lax_description=""" +@util._wraps(np.argpartition, lax_description=""" The JAX version requires the ``kth`` argument to be a static integer rather than a general array. This is implemented via two calls to :func:`jax.lax.top_k`. 
If you're only accessing the top or bottom k values of the output, it may be more @@ -3603,7 +3584,7 @@ NaNs which have the negative bit set are sorted to the beginning of the array. @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: # TODO(jakevdp): handle NaN values like numpy. - _check_arraylike("partition", a) + util._check_arraylike("partition", a) arr = asarray(a) if issubdtype(arr.dtype, np.complexfloating): raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.") @@ -3640,25 +3621,25 @@ def _roll(a, shift, axis): np.broadcast_to(axis, b_shape)): i = _canonicalize_axis(i, a_ndim) a_shape_i = array(a_shape[i], dtype=np.int32) - x = remainder(lax.convert_element_type(x, np.int32), + x = ufuncs.remainder(lax.convert_element_type(x, np.int32), lax.max(a_shape_i, np.int32(1))) a = lax.concatenate((a, a), i) a = lax.dynamic_slice_in_dim(a, a_shape_i - x, a_shape[i], axis=i) return a -@_wraps(np.roll) +@util._wraps(np.roll) def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None): - _check_arraylike("roll", a,) + util._check_arraylike("roll", a,) if isinstance(axis, list): axis = tuple(axis) return _roll(a, shift, axis) -@_wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC) +@util._wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a, axis: int, start=0): - _check_arraylike("rollaxis", a) + util._check_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") a_ndim = ndim(a) axis = _canonicalize_axis(axis, a_ndim) @@ -3671,10 +3652,10 @@ def rollaxis(a, axis: int, start=0): return moveaxis(a, axis, start) -@_wraps(np.packbits) +@util._wraps(np.packbits) @partial(jit, static_argnames=('axis', 'bitorder')) def packbits(a, axis: Optional[int] = None, bitorder='big'): - _check_arraylike("packbits", a) + util._check_arraylike("packbits", a) if not 
(issubdtype(_dtype(a), integer) or issubdtype(_dtype(a), bool_)): raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: @@ -3699,10 +3680,10 @@ def packbits(a, axis: Optional[int] = None, bitorder='big'): return swapaxes(packed, axis, -1) -@_wraps(np.unpackbits) +@util._wraps(np.unpackbits) @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): - _check_arraylike("unpackbits", a) + util._check_arraylike("unpackbits", a) if _dtype(a) != uint8: raise TypeError("Expected an input array of unsigned byte data type") if bitorder not in ['little', 'big']: @@ -3719,7 +3700,7 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): return swapaxes(unpacked, axis, -1) -@_wraps(np.take, skip_params=['out'], +@util._wraps(np.take, skip_params=['out'], lax_description=""" By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the ``mode`` parameter (see below). @@ -3751,7 +3732,7 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.take is not supported.") - _check_arraylike("take", a, indices) + util._check_arraylike("take", a, indices) a = asarray(a) indices = asarray(indices) @@ -3769,7 +3750,7 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None, # TODO(phawkins): we have no way to report out of bounds errors yet. 
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") elif mode == "wrap": - indices = mod(indices, _lax_const(indices, a.shape[axis_idx])) + indices = ufuncs.mod(indices, _lax_const(indices, a.shape[axis_idx])) gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS elif mode == "clip": gather_mode = lax.GatherScatterMode.CLIP @@ -3824,12 +3805,12 @@ See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds indexing in JAX. """ -@_wraps(np.take_along_axis, update_doc=False, +@util._wraps(np.take_along_axis, update_doc=False, lax_description=TAKE_ALONG_AXIS_DOC) @partial(jit, static_argnames=('axis', 'mode')) def take_along_axis(arr, indices, axis: Optional[int], mode: Optional[Union[str, lax.GatherScatterMode]] = None): - _check_arraylike("take_along_axis", arr, indices) + util._check_arraylike("take_along_axis", arr, indices) index_dtype = dtypes.dtype(indices) if not dtypes.issubdtype(index_dtype, integer): raise TypeError("take_along_axis indices must be of integer type, got " @@ -4435,47 +4416,47 @@ def _static_idx(idx: slice, size: DimSize): return stop + k + 1, start + 1, -step, True -@_wraps(np.blackman) +@util._wraps(np.blackman) def blackman(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.blackman") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: return ones(M, dtype) n = lax.iota(dtype, M) - return 0.42 - 0.5 * cos(2 * pi * n / (M - 1)) + 0.08 * cos(4 * pi * n / (M - 1)) + return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) -@_wraps(np.bartlett) +@util._wraps(np.bartlett) def bartlett(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: return ones(M, dtype) n = lax.iota(dtype, M) - return 1 - abs(2 * n + 1 - M) / (M - 1) + return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) -@_wraps(np.hamming) +@util._wraps(np.hamming) def hamming(M: int) -> Array: M = 
core.concrete_or_error(int, M, "M argument of jnp.hamming") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: return ones(M, dtype) n = lax.iota(dtype, M) - return 0.54 - 0.46 * cos(2 * pi * n / (M - 1)) + return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) -@_wraps(np.hanning) +@util._wraps(np.hanning) def hanning(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.hanning") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: return ones(M, dtype) n = lax.iota(dtype, M) - return 0.5 * (1 - cos(2 * pi * n / (M - 1))) + return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) -@_wraps(np.kaiser) +@util._wraps(np.kaiser) def kaiser(M: int, beta: ArrayLike) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") dtype = dtypes.canonicalize_dtype(float_) @@ -4483,12 +4464,12 @@ def kaiser(M: int, beta: ArrayLike) -> Array: return ones(M, dtype) n = lax.iota(dtype, M) alpha = 0.5 * (M - 1) - return i0(beta * sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta) + return i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta) def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array: x1, x2 = xs - return any(x2 != 0) + return reductions.any(x2 != 0) def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]: x1, x2 = xs @@ -4496,40 +4477,40 @@ def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) -@_wraps(np.gcd, module='numpy') +@util._wraps(np.gcd, module='numpy') @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: - _check_arraylike("gcd", x1, x2) - x1, x2 = _promote_dtypes(x1, x2) + util._check_arraylike("gcd", x1, x2) + x1, x2 = util._promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): raise ValueError("Arguments to jax.numpy.gcd must be integers.") x1, x2 = broadcast_arrays(x1, x2) - gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (abs(x1), abs(x2))) + gcd, _ = lax.while_loop(_gcd_cond_fn, 
_gcd_body_fn, (ufuncs.abs(x1), ufuncs.abs(x2))) return gcd -@_wraps(np.lcm, module='numpy') +@util._wraps(np.lcm, module='numpy') @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: - _check_arraylike("lcm", x1, x2) - x1, x2 = _promote_dtypes(x1, x2) - x1, x2 = abs(x1), abs(x2) + util._check_arraylike("lcm", x1, x2) + x1, x2 = util._promote_dtypes(x1, x2) + x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) if not issubdtype(_dtype(x1), integer): raise ValueError("Arguments to jax.numpy.lcm must be integers.") d = gcd(x1, x2) return where(d == 0, _lax_const(d, 0), - multiply(x1, floor_divide(x2, d))) + ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) -@_wraps(np.extract) +@util._wraps(np.extract) def extract(condition: ArrayLike, arr: ArrayLike) -> Array: return compress(ravel(condition), ravel(arr)) -@_wraps(np.compress, skip_params=['out']) +@util._wraps(np.compress, skip_params=['out']) def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None, out: None = None) -> Array: - _check_arraylike("compress", condition, a) + util._check_arraylike("compress", condition, a) condition_arr = asarray(condition).astype(bool) if out is not None: raise NotImplementedError("The 'out' argument to jnp.compress is not supported.") @@ -4541,24 +4522,24 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None, else: arr = moveaxis(a, axis, 0) condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:] - if any(extra): + if reductions.any(extra): raise ValueError("condition contains entries that are out of bounds") arr = arr[:condition_arr.shape[0]] return moveaxis(arr[condition_arr], 0, axis) -@_wraps(np.cov) +@util._wraps(np.cov) @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True, bias: bool = False, ddof: Optional[int] = None, fweights: Optional[ArrayLike] = None, aweights: Optional[ArrayLike] = None) -> Array: if y is not None: - m, y = 
_promote_args_inexact("cov", m, y) + m, y = util._promote_args_inexact("cov", m, y) if y.ndim > 2: raise ValueError("y has more than 2 dimensions") else: - m, = _promote_args_inexact("cov", m) + m, = util._promote_args_inexact("cov", m) if m.ndim > 2: raise ValueError("m has more than 2 dimensions") # same as numpy error @@ -4579,7 +4560,7 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True, w: Optional[Array] = None if fweights is not None: - _check_arraylike("cov", fweights) + util._check_arraylike("cov", fweights) if ndim(fweights) > 1: raise RuntimeError("cannot handle multidimensional fweights") if shape(fweights)[0] != X.shape[1]: @@ -4587,18 +4568,18 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True, if not issubdtype(_dtype(fweights), integer): raise TypeError("fweights must be integer.") # Ensure positive fweights; note that numpy raises an error on negative fweights. - w = asarray(abs(fweights)) + w = asarray(ufuncs.abs(fweights)) if aweights is not None: - _check_arraylike("cov", aweights) + util._check_arraylike("cov", aweights) if ndim(aweights) > 1: raise RuntimeError("cannot handle multidimensional aweights") if shape(aweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and aweights") # Ensure positive aweights: note that numpy raises an error for negative aweights. 
- aweights = abs(aweights) + aweights = ufuncs.abs(aweights) w = asarray(aweights) if w is None else w * asarray(aweights) - avg, w_sum = average(X, axis=1, weights=w, returned=True) + avg, w_sum = reductions.average(X, axis=1, weights=w, returned=True) w_sum = w_sum[0] if w is None: @@ -4608,41 +4589,41 @@ def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True, elif aweights is None: f = w_sum - ddof else: - f = w_sum - ddof * sum(w * aweights) / w_sum + f = w_sum - ddof * reductions.sum(w * aweights) / w_sum X = X - avg[:, None] X_T = X.T if w is None else (X * lax.broadcast_to_rank(w, X.ndim)).T - return true_divide(dot(X, X_T.conj()), f).squeeze() + return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() -@_wraps(np.corrcoef) +@util._wraps(np.corrcoef) @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array: - _check_arraylike("corrcoef", x) + util._check_arraylike("corrcoef", x) c = cov(x, y, rowvar) if len(shape(c)) == 0: # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise - return divide(c, c) + return ufuncs.divide(c, c) d = diag(c) - stddev = sqrt(real(d)).astype(c.dtype) + stddev = ufuncs.sqrt(ufuncs.real(d)).astype(c.dtype) c = c / stddev[:, None] / stddev[None, :] - real_part = clip(real(c), -1, 1) + real_part = clip(ufuncs.real(c), -1, 1) if iscomplexobj(c): - complex_part = clip(imag(c), -1, 1) + complex_part = clip(ufuncs.imag(c), -1, 1) c = lax.complex(real_part, complex_part) else: c = real_part return c -@_wraps(np.quantile, skip_params=['out', 'overwrite_input']) +@util._wraps(np.quantile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, 
interpolation: None = None) -> Array: - _check_arraylike("quantile", a, q) + util._check_arraylike("quantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.quantile does not support overwrite_input=True or " "out != None") @@ -4652,13 +4633,13 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, .. "Use 'method=' instead.", DeprecationWarning) return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False) -@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) +@util._wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: - _check_arraylike("nanquantile", a, q) + util._check_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -4673,7 +4654,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " "'midpoint', or 'nearest'") - a, = _promote_dtypes_inexact(a) + a, = util._promote_dtypes_inexact(a) keepdim = [] if issubdtype(a.dtype, np.complexfloating): raise ValueError("quantile does not support complex input, as the operation is poorly defined.") @@ -4709,10 +4690,10 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], a_shape = shape(a) if squash_nans: - a = where(isnan(a), nan, a) # Ensure nans are positive so they sort to the end. + a = where(ufuncs.isnan(a), nan, a) # Ensure nans are positive so they sort to the end. 
a = lax.sort(a, dimension=axis) - counts = sum(logical_not(isnan(a)), axis=axis, dtype=q.dtype, - keepdims=keepdims) + counts = reductions.sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, + keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) @@ -4738,7 +4719,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], index[axis] = high high_value = a[tuple(index)] else: - a = where(any(isnan(a), axis=axis, keepdims=True), nan, a) + a = where(reductions.any(ufuncs.isnan(a), axis=axis, keepdims=True), nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -4823,7 +4804,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) -@_wraps(np.searchsorted, skip_params=['sorter'], +@util._wraps(np.searchsorted, skip_params=['sorter'], extra_params=_dedent(""" method : str One of 'scan' (default), 'sort' or 'compare_all'. Controls the method used by the @@ -4834,7 +4815,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt @partial(jit, static_argnames=('side', 'sorter', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: None = None, *, method: str = 'scan') -> Array: - _check_arraylike("searchsorted", a, v) + util._check_arraylike("searchsorted", a, v) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. 
" "Expected one of ['left', 'right'].") @@ -4845,7 +4826,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', raise NotImplementedError("sorter is not implemented") if ndim(a) != 1: raise ValueError("a should be 1-dimensional") - a, v = _promote_dtypes(a, v) + a, v = util._promote_dtypes(a, v) dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 if len(a) == 0: return zeros_like(v, dtype=dtype) @@ -4856,10 +4837,10 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) -@_wraps(np.digitize) +@util._wraps(np.digitize) @partial(jit, static_argnames=('right',)) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: - _check_arraylike("digitize", x, bins) + util._check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) if bins_arr.ndim != 1: @@ -4879,11 +4860,11 @@ Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in See the :func:`jax.lax.switch` documentation for more information. 
""" -@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC) +@util._wraps(np.piecewise, lax_description=_PIECEWISE_DOC) def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], funclist: List[Union[ArrayLike, Callable[..., Array]]], *args, **kw) -> Array: - _check_arraylike("piecewise", x) + util._check_arraylike("piecewise", x) nc, nf = len(condlist), len(funclist) if nf == nc + 1: funclist = funclist[-1:] + funclist[:-1] @@ -4903,7 +4884,7 @@ def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike], *args, **kw) -> Array: funcdict = dict(funcs) funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)] - indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0) + indices = argmax(reductions.cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0) dtype = _dtype(x) def _call(f): return lambda x: f(x, *args, **kw).astype(dtype) @@ -4913,46 +4894,46 @@ def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike], return vectorize(lax.switch, excluded=(1,))(indices, funclist, x) -@_wraps(np.percentile, skip_params=['out', 'overwrite_input']) +@util._wraps(np.percentile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: - _check_arraylike("percentile", a, q) - q, = _promote_dtypes_inexact(q) + util._check_arraylike("percentile", a, q) + q, = util._promote_dtypes_inexact(q) return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, interpolation=interpolation, method=method, keepdims=keepdims) -@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) +@util._wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) 
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: - _check_arraylike("nanpercentile", a, q) - q = true_divide(q, float32(100.0)) + util._check_arraylike("nanpercentile", a, q) + q = ufuncs.true_divide(q, float32(100.0)) return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, interpolation=interpolation, method=method, keepdims=keepdims) -@_wraps(np.median, skip_params=['out', 'overwrite_input']) +@util._wraps(np.median, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: - _check_arraylike("median", a) + util._check_arraylike("median", a) return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') -@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) +@util._wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: - _check_arraylike("nanmedian", a) + util._check_arraylike("nanmedian", a) return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') @@ -5028,7 +5009,7 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: if type is not None: raise NotImplementedError("`type` argument of array.view() is not supported.") - _check_arraylike("view", arr) + 
util._check_arraylike("view", arr) arr = asarray(arr) dtypes.check_user_dtype_supported(dtype, "view") @@ -5085,7 +5066,7 @@ def _notimplemented_flat(self): "consider arr.flatten() instead.") -@_wraps(np.place, lax_description=""" +@util._wraps(np.place, lax_description=""" Numpy function :func:`numpy.place` is not available in JAX and will raise a :class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place, and in JAX arrays are immutable. A JAX-compatible approach to array updates @@ -5098,7 +5079,7 @@ def place(*args, **kwargs): "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.") -@_wraps(np.put, lax_description=""" +@util._wraps(np.put, lax_description=""" Numpy function :func:`numpy.put` is not available in JAX and will raise a :class:`NotImplementedError`, because ``np.put`` modifies its arguments in-place, and in JAX arrays are immutable. A JAX-compatible approach to array updates @@ -5161,7 +5142,7 @@ _JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl) _HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) def __array_module__(self, types): - if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + if _all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): return jax.numpy else: return NotImplemented @@ -5175,7 +5156,7 @@ def _compress_method(a: ArrayLike, condition: ArrayLike, return compress(condition, a, axis, out) -@_wraps(lax.broadcast, lax_description=""" +@util._wraps(lax.broadcast, lax_description=""" Deprecated. Use :func:`jax.lax.broadcast` instead. """) def _deprecated_broadcast(*args, **kwargs): @@ -5185,7 +5166,7 @@ def _deprecated_broadcast(*args, **kwargs): return lax.broadcast(*args, **kwargs) -@_wraps(lax.broadcast, lax_description=""" +@util._wraps(lax.broadcast, lax_description=""" Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead. 
""") def _deprecated_broadcast_in_dim(*args, **kwargs): @@ -5195,7 +5176,7 @@ def _deprecated_broadcast_in_dim(*args, **kwargs): return lax.broadcast_in_dim(*args, **kwargs) -@_wraps(lax.broadcast, lax_description=""" +@util._wraps(lax.broadcast, lax_description=""" Deprecated. Use :func:`jax.numpy.split` instead. """) def _deprecated_split(*args, **kwargs): @@ -5236,7 +5217,7 @@ def _chunk_iter(x, size): if size > x.shape[0]: yield x else: - num_chunks, tail = divmod(x.shape[0], size) + num_chunks, tail = ufuncs.divmod(x.shape[0], size) for i in range(num_chunks): yield lax.dynamic_slice_in_dim(x, i * size, size) if tail: @@ -5342,8 +5323,6 @@ class _IndexUpdateHelper: return f"_IndexUpdateHelper({repr(self.array)})" Array.at.__doc__ = _IndexUpdateHelper.__doc__ -_power_fn = power -_divide_fn = divide class _IndexUpdateRef: """Helper object to call indexed update functions for an (advanced) index. @@ -5452,7 +5431,7 @@ class _IndexUpdateRef: See :mod:`jax.ops` for details. """ - return _divide_fn( + return ufuncs.divide( self.array, scatter._scatter_update(ones_like(self.array), self.index, values, lax.scatter_mul, @@ -5468,7 +5447,7 @@ class _IndexUpdateRef: See :mod:`jax.ops` for details. 
""" - return _power_fn( + return ufuncs.power( self.array, scatter._scatter_update(ones_like(self.array), self.index, values, lax.scatter_mul, @@ -5510,52 +5489,52 @@ _array_operators = { "setitem": _unimplemented_setitem, "copy": _copy, "deepcopy": _deepcopy, - "neg": negative, - "pos": positive, - "eq": _defer_to_unrecognized_arg("==", equal), - "ne": _defer_to_unrecognized_arg("!=", not_equal), - "lt": _defer_to_unrecognized_arg("<", less), - "le": _defer_to_unrecognized_arg("<=", less_equal), - "gt": _defer_to_unrecognized_arg(">", greater), - "ge": _defer_to_unrecognized_arg(">=", greater_equal), - "abs": abs, - "add": _defer_to_unrecognized_arg("+", add), - "radd": _defer_to_unrecognized_arg("+", add, swap=True), - "sub": _defer_to_unrecognized_arg("-", subtract), - "rsub": _defer_to_unrecognized_arg("-", subtract, swap=True), - "mul": _defer_to_unrecognized_arg("*", multiply), - "rmul": _defer_to_unrecognized_arg("*", multiply, swap=True), - "div": _defer_to_unrecognized_arg("/", divide), - "rdiv": _defer_to_unrecognized_arg("/", divide, swap=True), - "truediv": _defer_to_unrecognized_arg("/", true_divide), - "rtruediv": _defer_to_unrecognized_arg("/", true_divide, swap=True), - "floordiv": _defer_to_unrecognized_arg("//", floor_divide), - "rfloordiv": _defer_to_unrecognized_arg("//", floor_divide, swap=True), - "divmod": _defer_to_unrecognized_arg("divmod", divmod), - "rdivmod": _defer_to_unrecognized_arg("divmod", divmod, swap=True), - "mod": _defer_to_unrecognized_arg("%", mod), - "rmod": _defer_to_unrecognized_arg("%", mod, swap=True), - "pow": _defer_to_unrecognized_arg("**", power), - "rpow": _defer_to_unrecognized_arg("**", power, swap=True), + "neg": ufuncs.negative, + "pos": ufuncs.positive, + "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), + "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), + "lt": _defer_to_unrecognized_arg("<", ufuncs.less), + "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), + "gt": 
_defer_to_unrecognized_arg(">", ufuncs.greater), + "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), + "abs": ufuncs.abs, + "add": _defer_to_unrecognized_arg("+", ufuncs.add), + "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), + "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), + "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), + "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), + "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), + "div": _defer_to_unrecognized_arg("/", ufuncs.divide), + "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), + "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), + "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), + "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), + "rfloordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide, swap=True), + "divmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod), + "rdivmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod, swap=True), + "mod": _defer_to_unrecognized_arg("%", ufuncs.mod), + "rmod": _defer_to_unrecognized_arg("%", ufuncs.mod, swap=True), + "pow": _defer_to_unrecognized_arg("**", ufuncs.power), + "rpow": _defer_to_unrecognized_arg("**", ufuncs.power, swap=True), "matmul": _defer_to_unrecognized_arg("@", matmul), "rmatmul": _defer_to_unrecognized_arg("@", matmul, swap=True), - "and": _defer_to_unrecognized_arg("&", bitwise_and), - "rand": _defer_to_unrecognized_arg("&", bitwise_and, swap=True), - "or": _defer_to_unrecognized_arg("|", bitwise_or), - "ror": _defer_to_unrecognized_arg("|", bitwise_or, swap=True), - "xor": _defer_to_unrecognized_arg("^", bitwise_xor), - "rxor": _defer_to_unrecognized_arg("^", bitwise_xor, swap=True), - "invert": bitwise_not, - "lshift": _defer_to_unrecognized_arg("<<", left_shift), - "rshift": _defer_to_unrecognized_arg(">>", right_shift), - "rlshift": _defer_to_unrecognized_arg("<<", left_shift, 
swap=True), - "rrshift": _defer_to_unrecognized_arg(">>", right_shift, swap=True), + "and": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and), + "rand": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and, swap=True), + "or": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or), + "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), + "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), + "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), + "invert": ufuncs.bitwise_not, + "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), + "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), + "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), + "rrshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift, swap=True), "round": _operator_round, } _array_methods = { - "all": all, - "any": any, + "all": reductions.all, + "any": reductions.any, "argmax": argmax, "argmin": argmin, "argpartition": argpartition, @@ -5563,22 +5542,22 @@ _array_methods = { "astype": _astype, "choose": choose, "clip": _clip, - "conj": conj, - "conjugate": conjugate, + "conj": ufuncs.conj, + "conjugate": ufuncs.conjugate, "compress": _compress_method, "copy": copy, - "cumprod": cumprod, - "cumsum": cumsum, + "cumprod": reductions.cumprod, + "cumsum": reductions.cumsum, "diagonal": diagonal, "dot": dot, "flatten": ravel, "item": _item, - "max": max, - "mean": mean, - "min": min, + "max": reductions.max, + "mean": reductions.mean, + "min": reductions.min, "nonzero": nonzero, - "prod": prod, - "ptp": ptp, + "prod": reductions.prod, + "ptp": reductions.ptp, "ravel": ravel, "repeat": repeat, "reshape": _reshape, @@ -5586,13 +5565,13 @@ _array_methods = { "searchsorted": searchsorted, "sort": sort, "squeeze": squeeze, - "std": std, - "sum": sum, + "std": reductions.std, + "sum": reductions.sum, "swapaxes": swapaxes, "take": take, "trace": trace, "transpose": _transpose, - "var": var, + "var": reductions.var, "view": _view, 
# Methods exposed in order to avoid circular imports @@ -5609,8 +5588,8 @@ _array_methods = { _array_properties = { "flat": _notimplemented_flat, "T": transpose, - "real": real, - "imag": imag, + "real": ufuncs.real, + "imag": ufuncs.imag, "nbytes": _nbytes, "itemsize": _itemsize, "at": _IndexUpdateHelper, diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 45369e4d0..4ea377ae5 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -27,6 +27,7 @@ from jax import lax from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike from jax._src.util import canonicalize_axis from jax._src.typing import ArrayLike, Array @@ -37,7 +38,7 @@ def _T(x: ArrayLike) -> Array: def _H(x: ArrayLike) -> Array: - return jnp.conjugate(jnp.swapaxes(x, -1, -2)) + return ufuncs.conjugate(jnp.swapaxes(x, -1, -2)) def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @@ -136,12 +137,12 @@ def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array: _check_arraylike("jnp.linalg.matrix_rank", M) M, = _promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: - return jnp.any(M != 0).astype(jnp.int32) + return (M != 0).any().astype(jnp.int32) S = svd(M, full_matrices=False, compute_uv=False) if tol is None: tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps tol = jnp.expand_dims(tol, np.ndim(tol)) - return jnp.sum(S > tol, axis=-1) + return reductions.sum(S > tol, axis=-1) @custom_jvp @@ -149,22 +150,22 @@ def _slogdet_lu(a: Array) -> Tuple[Array, Array]: dtype = lax.dtype(a) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) - is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) + is_zero = reductions.any(diag == jnp.array(0, dtype=dtype), axis=-1) iota = 
lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype), range(pivot.ndim - 1)) - parity = jnp.count_nonzero(pivot != iota, axis=-1) + parity = reductions.count_nonzero(pivot != iota, axis=-1) if jnp.iscomplexobj(a): - sign = jnp.prod(diag / jnp.abs(diag).astype(diag.dtype), axis=-1) + sign = reductions.prod(diag / ufuncs.abs(diag).astype(diag.dtype), axis=-1) else: sign = jnp.array(1, dtype=dtype) - parity = parity + jnp.count_nonzero(diag < 0, axis=-1) + parity = parity + reductions.count_nonzero(diag < 0, axis=-1) sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where( is_zero, jnp.array(-jnp.inf, dtype=dtype), - jnp.sum(jnp.log(jnp.abs(diag)).astype(dtype), axis=-1)) - return sign, jnp.real(logdet) + reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1)) + return sign, ufuncs.real(logdet) @custom_jvp def _slogdet_qr(a: Array) -> Tuple[Array, Array]: @@ -179,11 +180,11 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]: # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # the trace of the log-absolute values, and we compute the sign separately. - log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) - sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) + log_abs_det = jnp.trace(ufuncs.log(ufuncs.abs(a)), axis1=-2, axis2=-1) + sign_diag = reductions.prod(ufuncs.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. 
- sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) + sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det @_wraps( @@ -217,8 +218,8 @@ def _slogdet_jvp(primals, tangents): sign, ans = slogdet(x) ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2) if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating): - sign_dot = (ans_dot - jnp.real(ans_dot).astype(ans_dot.dtype)) * sign - ans_dot = jnp.real(ans_dot) + sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign + ans_dot = ufuncs.real(ans_dot) else: sign_dot = jnp.zeros_like(sign) return (sign, ans), (sign_dot, ans_dot) @@ -292,18 +293,18 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]: diag = jnp.diagonal(lu, axis1=-2, axis2=-1) iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype), range(pivots.ndim - 1)) - parity = jnp.count_nonzero(pivots != iota, axis=-1) + parity = reductions.count_nonzero(pivots != iota, axis=-1) sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype) # partial_det[:, -1] contains the full determinant and # partial_det[:, -2] contains det(u) / u_{nn}. 
- partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None] + partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None] lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2]) permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1])) iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1))) # filter out any matrices that are not full rank d = jnp.ones(x.shape[:-1], x.dtype) d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False) - d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1) + d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1) d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:]) x = jnp.where(d, jnp.zeros_like(x), x) # first filter x = x[iotas[:-1] + (permutation, slice(None))] @@ -344,7 +345,7 @@ def det(a: ArrayLike) -> Array: return _det_3x3(a) elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]: sign, logdet = slogdet(a) - return sign * jnp.exp(logdet).astype(sign.dtype) + return sign * ufuncs.exp(logdet).astype(sign.dtype) else: msg = "Argument to _det() must have shape [..., n, n], got {}" raise ValueError(msg.format(a_shape)) @@ -424,7 +425,7 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None, m, n = arr.shape[-2:] if m == 0 or n == 0: return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) - arr = jnp.conj(arr) + arr = ufuncs.conj(arr) if rcond is None: max_rows_cols = max(arr.shape[-2:]) rcond = 10. 
* max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) @@ -435,7 +436,7 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None, rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1)) cutoff = rcond * s[..., 0:1] s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) - res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]), + res = jnp.matmul(_T(vh), ufuncs.divide(_T(u), s[..., jnp.newaxis]), precision=lax.Precision.HIGHEST) return lax.convert_element_type(res, arr.dtype) @@ -495,7 +496,7 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None, # NumPy has an undocumented behavior that admits arbitrary rank inputs if # `ord` is None: https://github.com/numpy/numpy/issues/14215 if ord is None: - return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims)) + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), keepdims=keepdims)) axis = tuple(range(ndim)) elif isinstance(axis, tuple): axis = tuple(canonicalize_axis(x, ndim) for x in axis) @@ -505,21 +506,21 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None, num_axes = len(axis) if num_axes == 1: if ord is None or ord == 2: - return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, - keepdims=keepdims)) + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, + keepdims=keepdims)) elif ord == jnp.inf: - return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims) + return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif ord == -jnp.inf: - return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims) + return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif ord == 0: - return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, - axis=axis, keepdims=keepdims) + return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, + axis=axis, keepdims=keepdims) elif ord == 1: # Numpy has a special case for ord == 1 as an optimization. 
We don't # really need the optimization (XLA could do it for us), but the Numpy # code has slightly different type promotion semantics, so we need a # special case too. - return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims) + return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif isinstance(ord, str): msg = f"Invalid order '{ord}' for vector norm." if ord == "inf": @@ -528,46 +529,46 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None, msg += "Use '-jax.numpy.inf' instead." raise ValueError(msg) else: - abs_x = jnp.abs(x) + abs_x = ufuncs.abs(x) ord_arr = lax_internal._const(abs_x, ord) ord_inv = lax_internal._const(abs_x, 1. / ord_arr) - out = jnp.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) - return jnp.power(out, ord_inv) + out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) + return ufuncs.power(out, ord_inv) elif num_axes == 2: row_axis, col_axis = cast(Tuple[int, ...], axis) if ord is None or ord in ('f', 'fro'): - return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, - keepdims=keepdims)) + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, + keepdims=keepdims)) elif ord == 1: if not keepdims and col_axis > row_axis: col_axis -= 1 - return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims), - axis=col_axis, keepdims=keepdims) + return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, keepdims=keepdims) elif ord == -1: if not keepdims and col_axis > row_axis: col_axis -= 1 - return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims), - axis=col_axis, keepdims=keepdims) + return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, keepdims=keepdims) elif ord == jnp.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 - return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims), + return reductions.amax(reductions.sum(ufuncs.abs(x), 
axis=col_axis, keepdims=keepdims), axis=row_axis, keepdims=keepdims) elif ord == -jnp.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 - return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims), + return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims), axis=row_axis, keepdims=keepdims) elif ord in ('nuc', 2, -2): x = jnp.moveaxis(x, axis, (-2, -1)) if ord == 2: - reducer = jnp.amax + reducer = reductions.amax elif ord == -2: - reducer = jnp.amin + reducer = reductions.amin else: # `sum` takes an extra dtype= argument, unlike `amax` and `amin`. - reducer = jnp.sum # type: ignore[assignment] + reducer = reductions.sum # type: ignore[assignment] y = reducer(svd(x, compute_uv=False), axis=-1) if keepdims: y = jnp.expand_dims(y, axis) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 23d7b9bd6..23d74e98b 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -24,9 +24,11 @@ from jax import lax from jax._src import dtypes from jax._src import core from jax._src.numpy.lax_numpy import ( - all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, - diag, dot, finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, - trim_zeros_tol, true_divide, vander, zeros) + arange, argmin, array, asarray, atleast_1d, concatenate, convolve, + diag, dot, finfo, full, ones, outer, roll, trim_zeros, + trim_zeros_tol, vander, zeros) +from jax._src.numpy.ufuncs import maximum, true_divide, sqrt +from jax._src.numpy.reductions import all from jax._src.numpy import linalg from jax._src.numpy.util import ( _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 2949415d3..4dfedda56 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -27,9 +27,11 @@ from jax._src import core from jax._src import dtypes from jax._src.lax import lax as lax_internal from 
jax._src.numpy.lax_numpy import ( - any, append, arange, array, asarray, concatenate, cumsum, diff, - empty, full_like, isnan, lexsort, moveaxis, nonzero, ones, ravel, + append, arange, array, asarray, concatenate, diff, + empty, full_like, lexsort, moveaxis, nonzero, ones, ravel, sort, where, zeros) +from jax._src.numpy.reductions import any, cumsum +from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import _check_arraylike, _wraps from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 0c3529576..c709ad6f0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -27,6 +27,8 @@ from jax._src import dtypes from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import reductions +from jax._src.numpy.util import _check_arraylike, _promote_dtypes Array = Any @@ -98,7 +100,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, if core.is_empty_shape(indexer.slice_shape): return x - x, y = jnp._promote_dtypes(x, y) + x, y = _promote_dtypes(x, y) # Broadcast `y` to the slice output shape. 
y = jnp.broadcast_to(y, tuple(indexer.slice_shape)) @@ -155,7 +157,7 @@ def _segment_update(name: str, bucket_size: Optional[int] = None, reducer: Optional[Callable] = None, mode: Optional[lax.GatherScatterMode] = None) -> Array: - jnp._check_arraylike(name, data, segment_ids) + _check_arraylike(name, data, segment_ids) mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) segment_ids = jnp.asarray(segment_ids) @@ -239,7 +241,7 @@ def segment_sum(data: Array, """ return _segment_update( "segment_sum", data, segment_ids, lax.scatter_add, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode) + indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode) def segment_prod(data: Array, @@ -295,7 +297,7 @@ def segment_prod(data: Array, """ return _segment_update( "segment_prod", data, segment_ids, lax.scatter_mul, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode) + indices_are_sorted, unique_indices, bucket_size, reductions.prod, mode=mode) def segment_max(data: Array, @@ -350,7 +352,7 @@ def segment_max(data: Array, """ return _segment_update( "segment_max", data, segment_ids, lax.scatter_max, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode) + indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode) def segment_min(data: Array, @@ -405,4 +407,4 @@ def segment_min(data: Array, """ return _segment_update( "segment_min", data, segment_ids, lax.scatter_min, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode) + indices_are_sorted, unique_indices, bucket_size, reductions.min, mode=mode) diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 81d736550..d7ffa3746 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -17,7 +17,8 @@ from typing import overload, Literal, Optional, Tuple, Union import jax from jax import lax from jax import numpy 
as jnp -from jax._src.numpy.lax_numpy import _reduction_dims, _promote_args_inexact +from jax._src.numpy.reductions import _reduction_dims +from jax._src.numpy.util import _promote_args_inexact from jax._src.typing import Array, ArrayLike import numpy as np diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 743f86f45..2af62225d 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -42,7 +42,8 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils from jax._src.lib import gpu_prng from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy import lax_numpy +from jax._src.numpy.lax_numpy import _set_device_array_base_attributes +from jax._src.numpy.util import _register_stackable from jax._src.sharding import ( NamedSharding, PmapSharding, GSPMDSharding) from jax._src.util import canonicalize_axis, safe_map, safe_zip @@ -256,10 +257,10 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta): def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False -lax_numpy._set_device_array_base_attributes(PRNGKeyArray, include=[ +_set_device_array_base_attributes(PRNGKeyArray, include=[ '__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape', 'transpose', 'flatten', 'T']) -lax_numpy._register_stackable(PRNGKeyArray) +_register_stackable(PRNGKeyArray) basearray.Array.register(PRNGKeyArray) diff --git a/jax/_src/random.py b/jax/_src/random.py index 9b30a3de1..cf26e6783 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -37,9 +37,8 @@ from jax._src.core import NamedShape from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.lax import lax as lax_internal -from jax._src.numpy.lax_numpy import ( - _arraylike, _check_arraylike, _convert_and_clip_integer, - _promote_dtypes_inexact) +from jax._src.numpy.lax_numpy import _convert_and_clip_integer +from jax._src.numpy.util import _arraylike, _check_arraylike, _promote_dtypes_inexact from jax._src.typing import Array, ArrayLike, DTypeLike 
from jax._src.util import canonicalize_axis diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py index ec57a10af..907461d5d 100644 --- a/jax/_src/scipy/cluster/vq.py +++ b/jax/_src/scipy/cluster/vq.py @@ -1,4 +1,4 @@ -# Copyright 2022 The JAX Authors. +# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,8 @@ import scipy.cluster.vq import textwrap from jax import vmap +import jax.numpy as jnp from jax._src.numpy.util import _wraps, _check_arraylike, _promote_dtypes_inexact -from jax._src.numpy.lax_numpy import argmin -from jax._src.numpy.linalg import norm _no_chkfinite_doc = textwrap.dedent(""" @@ -41,7 +40,7 @@ def vq(obs, code_book, check_finite=True): raise ValueError("ndim different than 1 or 2 are not supported") # explicitly rank promotion - dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs) - code = argmin(dist, axis=-1) + dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - code_book, axis=-1))(obs) + code = jnp.argmin(dist, axis=-1) dist_min = vmap(operator.getitem)(dist, code) return code, dist_min diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 25a38f89c..0075533df 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -22,15 +22,15 @@ import warnings from typing import cast, overload, Any, Literal, Optional, Tuple, Union import jax +import jax.numpy as jnp from jax import jit, vmap, jvp from jax import lax from jax._src import dtypes from jax._src.lax import linalg as lax_linalg from jax._src.lax import qdwh -from jax._src.numpy.lax_numpy import _check_arraylike -from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy import linalg as np_linalg +from jax._src.numpy.util import ( + _check_arraylike, _wraps, _promote_dtypes, _promote_dtypes_inexact, + 
_promote_dtypes_complex) from jax._src.typing import Array, ArrayLike @@ -125,7 +125,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: del overwrite_a, check_finite # unused - return np_linalg.det(a) + return jnp.linalg.det(a) @overload @@ -210,7 +210,7 @@ def schur(a: ArrayLike, output: str = 'real') -> Tuple[Array, Array]: lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: del overwrite_a, check_finite # unused - return np_linalg.inv(a) + return jnp.linalg.inv(a) @_wraps(scipy.linalg.lu_factor, @@ -332,7 +332,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " @partial(jit, static_argnames=('assume_a', 'lower')) def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: if assume_a != 'pos': - return np_linalg.solve(a, b) + return jnp.linalg.solve(a, b) a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) lax_linalg._check_solve_shapes(a, b) @@ -453,7 +453,7 @@ def _calc_P_Q(A: ArrayLike) -> Tuple[Array, Array, Array]: A = jnp.asarray(A) if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError('expected A to be a square matrix') - A_L1 = np_linalg.norm(A,1) + A_L1 = jnp.linalg.norm(A,1) n_squarings: Array U: Array V: Array @@ -484,7 +484,7 @@ def _solve_P_Q(P: ArrayLike, Q: ArrayLike, upper_triangular: bool = False) -> Ar if upper_triangular: return solve_triangular(Q, P) else: - return np_linalg.solve(Q, P) + return jnp.linalg.solve(Q, P) def _precise_dot(A: ArrayLike, B: ArrayLike) -> Array: return jnp.dot(A, B, precision=lax.Precision.HIGHEST) @@ -609,7 +609,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None, def block_diag(*arrs: 
ArrayLike) -> Array: if len(arrs) == 0: arrs = cast(Tuple[ArrayLike], (jnp.zeros((1, 0)),)) - arrs = cast(Tuple[ArrayLike], jnp._promote_dtypes(*arrs)) + arrs = cast(Tuple[ArrayLike], _promote_dtypes(*arrs)) bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2] if bad_shapes: raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at " @@ -948,8 +948,8 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra return T, Z def _update_T_Z(m, T, Z): - mu = np_linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m] - r = np_linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype) + mu = jnp.linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m] + r = jnp.linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype) c = mu[0] / r s = T[m, m-1] / r G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype) diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index 557912962..b0ef34219 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -24,7 +24,7 @@ import scipy.ndimage from jax._src import api from jax._src import util from jax import lax -from jax._src.numpy import lax_numpy as jnp +import jax.numpy as jnp from jax._src.numpy.util import _wraps from jax._src.typing import ArrayLike, Array from jax._src.util import safe_zip as zip diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 441c2c56b..81e09175c 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -22,13 +22,13 @@ import scipy.signal as osp_signal import jax import jax.numpy.fft +import jax.numpy as jnp from jax import lax from jax._src import dtypes from jax._src.lax.lax import PrecisionLike -from jax._src.numpy.lax_numpy import _check_arraylike -from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import linalg -from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex +from jax._src.numpy.util import ( + _check_arraylike, _wraps, 
_promote_dtypes_inexact, _promote_dtypes_complex) from jax._src.third_party.scipy import signal_helper from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index a62a28681..31ca115ee 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -29,8 +29,7 @@ from jax._src import api from jax._src import core from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.lax_numpy import ( - moveaxis, _promote_args_inexact, _promote_dtypes_inexact) +from jax._src.numpy.util import _promote_args_inexact, _promote_dtypes_inexact from jax._src.numpy.util import _wraps from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl @@ -725,7 +724,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: bessel_jn_fun = partial(_bessel_jn, v=v, n_iter=n_iter) for _ in range(z.ndim): bessel_jn_fun = vmap(bessel_jn_fun) - return moveaxis(bessel_jn_fun(z), -1, 0) + return jnp.moveaxis(bessel_jn_fun(z), -1, 0) def _gen_recurrence_mask( diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 81925fb4c..f03253e91 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -22,8 +22,7 @@ import jax.numpy as jnp from jax import jit from jax._src import dtypes from jax._src.api import vmap -from jax._src.numpy.lax_numpy import _check_arraylike -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import _check_arraylike, _wraps from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index bdf828a6d..357f99661 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -16,16 +16,16 @@ import scipy.stats as osp_stats from jax import lax 
+import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import xlogy, xlog1py @_wraps(osp_stats.bernoulli.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc) + k, p, loc = _promote_args_inexact("bernoulli.logpmf", k, p, loc) zero = _lax_const(k, 0) one = _lax_const(k, 1) x = lax.sub(k, loc) @@ -39,7 +39,7 @@ def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: @_wraps(osp_stats.bernoulli.cdf, update_doc=False) def cdf(k: ArrayLike, p: ArrayLike) -> Array: - k, p = jnp._promote_args_inexact('bernoulli.cdf', k, p) + k, p = _promote_args_inexact('bernoulli.cdf', k, p) zero, one = _lax_const(k, 0), _lax_const(k, 1) conds = [ jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one), @@ -52,7 +52,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array: @_wraps(osp_stats.bernoulli.ppf, update_doc=False) def ppf(q: ArrayLike, p: ArrayLike) -> Array: - q, p = jnp._promote_args_inexact('bernoulli.ppf', q, p) + q, p = _promote_args_inexact('bernoulli.ppf', q, p) zero, one = _lax_const(q, 0), _lax_const(q, 1) return jnp.where( jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one), diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index d4788ec1a..85cac2f7c 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -15,9 +15,9 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import 
Array, ArrayLike from jax.scipy.special import betaln, xlogy, xlog1py @@ -32,8 +32,8 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, log_linear_term = lax.add(xlogy(lax.sub(a, one), y), xlog1py(lax.sub(b, one), lax.neg(y))) log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale)) - return where(logical_or(lax.gt(x, lax.add(loc, scale)), - lax.lt(x, loc)), -inf, log_probs) + return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)), + lax.lt(x, loc)), -jnp.inf, log_probs) @_wraps(osp_stats.beta.pdf, update_doc=False) diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index a554a4359..ff0b40ac8 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -16,9 +16,9 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.scipy.special import betaln from jax._src.typing import Array, ArrayLike @@ -34,10 +34,10 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one)))) beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b)) log_probs = lax.add(combiln, beta_lns) - y_cond = logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc))) - log_probs = where(y_cond, -inf, log_probs) - n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero)) - return where(n_a_b_cond, nan, log_probs) + y_cond = jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc))) + log_probs = jnp.where(y_cond, -jnp.inf, log_probs) + n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero)) + return jnp.where(n_a_b_cond, jnp.nan, 
log_probs) @_wraps(osp_stats.betabinom.pmf, update_doc=False) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 5e5044f74..aceb73149 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -19,7 +19,7 @@ import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.numpy.util import _promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 2840bc183..e7ce803ad 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -16,9 +16,9 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -35,7 +35,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return where(lax.lt(x, loc), -inf, log_probs) + return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) @_wraps(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index ac2094592..36284537f 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -16,11 +16,10 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy import lax_numpy as jnp -from 
jax._src.numpy.util import _wraps +from jax._src.numpy.util import _promote_dtypes_inexact, _wraps from jax.scipy.special import gammaln, xlogy -from jax._src.numpy.lax_numpy import _promote_dtypes_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index 35e2830ca..07dad9ff3 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -15,8 +15,8 @@ import scipy.stats as osp_stats from jax import lax -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf +import jax.numpy as jnp +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -26,7 +26,7 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: log_scale = lax.log(scale) linear_term = lax.div(lax.sub(x, loc), scale) log_probs = lax.neg(lax.add(linear_term, log_scale)) - return where(lax.lt(x, loc), -inf, log_probs) + return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) @_wraps(osp_stats.expon.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 94ca30c49..a4da3748d 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -15,9 +15,9 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import gammaln, xlogy @@ -30,7 +30,7 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) log_probs 
= lax.sub(log_linear_term, shape_terms) - return where(lax.lt(x, loc), -inf, log_probs) + return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) @_wraps(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/_src/scipy/stats/gennorm.py b/jax/_src/scipy/stats/gennorm.py index 5d5234de8..be4cc202e 100644 --- a/jax/_src/scipy/stats/gennorm.py +++ b/jax/_src/scipy/stats/gennorm.py @@ -14,8 +14,7 @@ import scipy.stats as osp_stats from jax import lax -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 8bf885d84..b5ae60771 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -15,16 +15,16 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax.scipy.special import xlog1py from jax._src.typing import Array, ArrayLike @_wraps(osp_stats.geom.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc) + k, p, loc = _promote_args_inexact("geom.logpmf", k, p, loc) zero = _lax_const(k, 0) one = _lax_const(k, 1) x = lax.sub(k, loc) diff --git a/jax/_src/scipy/stats/kde.py b/jax/_src/scipy/stats/kde.py index 2276c4624..0f5234f61 100644 --- a/jax/_src/scipy/stats/kde.py +++ b/jax/_src/scipy/stats/kde.py @@ -21,8 +21,7 @@ import scipy.stats as osp_stats import jax.numpy as jnp from jax import jit, lax, random, vmap -from jax._src.numpy.lax_numpy import _check_arraylike, _promote_dtypes_inexact -from 
jax._src.numpy.util import _wraps +from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps from jax._src.tree_util import register_pytree_node_class from jax.scipy import linalg, special diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index f954f85d5..4729c9319 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -16,8 +16,7 @@ import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index e021d14d2..46f899778 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -16,10 +16,9 @@ import scipy.stats as osp_stats from jax.scipy.special import expit, logit from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact -from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/multinomial.py b/jax/_src/scipy/stats/multinomial.py index 49fa595be..1394826ad 100644 --- a/jax/_src/scipy/stats/multinomial.py +++ b/jax/_src/scipy/stats/multinomial.py @@ -15,7 +15,7 @@ import scipy.stats as osp_stats from jax import lax -from jax._src.numpy import lax_numpy as jnp +import jax.numpy as jnp from jax._src.numpy.util import _wraps, _promote_args_inexact, _promote_args_numeric from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index 
d74c96882..475ee7d29 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -19,8 +19,7 @@ import scipy.stats as osp_stats from jax import lax from jax import numpy as jnp -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_dtypes_inexact +from jax._src.numpy.util import _wraps, _promote_dtypes_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/nbinom.py b/jax/_src/scipy/stats/nbinom.py index fb1c5d3a3..55f874f58 100644 --- a/jax/_src/scipy/stats/nbinom.py +++ b/jax/_src/scipy/stats/nbinom.py @@ -16,9 +16,9 @@ import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike @@ -34,7 +34,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra ) log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p))) log_probs = lax.add(comb_term, log_linear_term) - return where(lax.lt(k, loc), -inf, log_probs) + return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) @_wraps(osp_stats.nbinom.pmf, update_doc=False) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index d74c47132..47e35bc44 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -18,10 +18,9 @@ import numpy as np import scipy.stats as osp_stats from jax import lax +import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy.util import _wraps -from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.numpy.util import _wraps, _promote_args_inexact from jax._src.typing import Array, 
ArrayLike
 from jax.scipy import special
diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py
index 7876992fd..0915f0421 100644
--- a/jax/_src/scipy/stats/pareto.py
+++ b/jax/_src/scipy/stats/pareto.py
@@ -16,9 +16,9 @@
 import scipy.stats as osp_stats

 from jax import lax
+import jax.numpy as jnp
 from jax._src.lax.lax import _const as _lax_const
-from jax._src.numpy.util import _wraps
-from jax._src.numpy.lax_numpy import _promote_args_inexact, inf, where
+from jax._src.numpy.util import _wraps, _promote_args_inexact
 from jax._src.typing import Array, ArrayLike


@@ -29,7 +29,7 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
   scaled_x = lax.div(lax.sub(x, loc), scale)
   normalize_term = lax.log(lax.div(scale, b))
   log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))))
-  return where(lax.lt(x, lax.add(loc, scale)), -inf, log_probs)
+  return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs)

 @_wraps(osp_stats.pareto.pdf, update_doc=False)
 def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py
index e135cc80a..de3f41c76 100644
--- a/jax/_src/scipy/stats/poisson.py
+++ b/jax/_src/scipy/stats/poisson.py
@@ -16,16 +16,16 @@
 import scipy.stats as osp_stats

 from jax import lax
+import jax.numpy as jnp
 from jax._src.lax.lax import _const as _lax_const
-from jax._src.numpy.util import _wraps
-from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps, _promote_args_inexact
 from jax._src.typing import Array, ArrayLike
 from jax.scipy.special import xlogy, gammaln, gammaincc


 @_wraps(osp_stats.poisson.logpmf, update_doc=False)
 def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
-  k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
+  k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
   zero = _lax_const(k, 0)
   x = lax.sub(k, loc)
   log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
@@ -37,7 +37,7 @@ def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:

 @_wraps(osp_stats.poisson.cdf, update_doc=False)
 def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
-  k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
+  k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
   zero = _lax_const(k, 0)
   x = lax.sub(k, loc)
   p = gammaincc(jnp.floor(1 + x), mu)
diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py
index aff476cfd..ff06f0e4f 100644
--- a/jax/_src/scipy/stats/t.py
+++ b/jax/_src/scipy/stats/t.py
@@ -18,8 +18,7 @@
 import scipy.stats as osp_stats

 from jax import lax
 from jax._src.lax.lax import _const as _lax_const
-from jax._src.numpy.util import _wraps
-from jax._src.numpy.lax_numpy import _promote_args_inexact
+from jax._src.numpy.util import _wraps, _promote_args_inexact
 from jax._src.typing import Array, ArrayLike

diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py
index 9838c9016..0f414a1a8 100644
--- a/jax/_src/scipy/stats/truncnorm.py
+++ b/jax/_src/scipy/stats/truncnorm.py
@@ -16,9 +16,8 @@
 import scipy.stats as osp_stats

 from jax import lax
-from jax._src.numpy import lax_numpy as jnp
-from jax._src.numpy.util import _wraps
-from jax._src.numpy.lax_numpy import _promote_args_inexact
+import jax.numpy as jnp
+from jax._src.numpy.util import _wraps, _promote_args_inexact
 from jax._src.scipy.stats import norm
 from jax._src.scipy.special import logsumexp, log_ndtr, ndtr

diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py
index de8084a53..72fdb85df 100644
--- a/jax/_src/scipy/stats/uniform.py
+++ b/jax/_src/scipy/stats/uniform.py
@@ -16,9 +16,9 @@
 import scipy.stats as osp_stats

 from jax import lax
-from jax._src.numpy.util import _wraps
-from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
+from jax.numpy import where, inf, logical_or
 from jax._src.typing import Array, ArrayLike
+from jax._src.numpy.util import _wraps, _promote_args_inexact


 @_wraps(osp_stats.uniform.logpdf, update_doc=False)
diff --git a/jax/_src/scipy/stats/vonmises.py b/jax/_src/scipy/stats/vonmises.py
index f8c9115c4..fc7b10579 100644
--- a/jax/_src/scipy/stats/vonmises.py
+++ b/jax/_src/scipy/stats/vonmises.py
@@ -15,8 +15,8 @@
 import scipy.stats as osp_stats

 from jax import lax
+import jax.numpy as jnp
 from jax._src.lax.lax import _const as _lax_const
-from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy.util import _wraps, _promote_args_inexact
 from jax._src.typing import Array, ArrayLike

diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index dc3ee4211..05c9185be 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -46,7 +46,7 @@ from jax._src import dtypes as _dtypes
 from jax._src.config import (flags, bool_env, config,
                              raise_persistent_cache_errors,
                              persistent_cache_min_compile_time_secs)
-from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
+from jax._src.numpy.util import _promote_dtypes, _promote_dtypes_inexact
 from jax._src.util import unzip2
 from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
diff --git a/jax/_src/third_party/numpy/linalg.py b/jax/_src/third_party/numpy/linalg.py
index c59f85eb9..f907b5a06 100644
--- a/jax/_src/third_party/numpy/linalg.py
+++ b/jax/_src/third_party/numpy/linalg.py
@@ -1,7 +1,7 @@
 import numpy as np

-from jax._src.numpy import lax_numpy as jnp
-from jax._src.numpy import linalg as la
+import jax.numpy as jnp
+import jax.numpy.linalg as la
 from jax._src.numpy.util import _check_arraylike, _wraps


diff --git a/jax/_src/third_party/scipy/betaln.py b/jax/_src/third_party/scipy/betaln.py
index 834310d3f..482aa3b73 100644
--- a/jax/_src/third_party/scipy/betaln.py
+++ b/jax/_src/third_party/scipy/betaln.py
@@ -1,7 +1,7 @@
 from jax import lax
 import jax.numpy as jnp
 from jax._src.typing import Array, ArrayLike
-from jax._src.numpy.lax_numpy import _promote_args_inexact
+from jax._src.numpy.util import _promote_args_inexact

 # Note: for mysterious reasons, annotating this leads to very slow mypy runs.
 # def algdiv(a: ArrayLike, b: ArrayLike) -> Array:
diff --git a/jax/_src/third_party/scipy/interpolate.py b/jax/_src/third_party/scipy/interpolate.py
index 978efccd9..3c240bfac 100644
--- a/jax/_src/third_party/scipy/interpolate.py
+++ b/jax/_src/third_party/scipy/interpolate.py
@@ -1,11 +1,10 @@
 from itertools import product

 import scipy.interpolate as osp_interpolate
+from jax.numpy import (asarray, broadcast_arrays, can_cast,
+                       empty, nan, searchsorted, where, zeros)
 from jax._src.tree_util import register_pytree_node
-from jax._src.numpy.lax_numpy import (_check_arraylike, _promote_dtypes_inexact,
-                                      asarray, broadcast_arrays, can_cast,
-                                      empty, nan, searchsorted, where, zeros)
-from jax._src.numpy.util import _wraps
+from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps


 def _ndim_coords_from_arrays(points, ndim=None):
diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py
index d56b9ba47..c32e90df8 100644
--- a/jax/_src/third_party/scipy/linalg.py
+++ b/jax/_src/third_party/scipy/linalg.py
@@ -3,7 +3,7 @@ from typing import Callable, Tuple
 import scipy.linalg

 from jax import jit, lax
-from jax._src.numpy import lax_numpy as jnp
+import jax.numpy as jnp
 from jax._src.numpy.linalg import norm
 from jax._src.numpy.util import _wraps
 from jax._src.scipy.linalg import rsf2csf, schur
diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py
index 1f9d0995d..713234694 100644
--- a/jax/_src/third_party/scipy/signal_helper.py
+++ b/jax/_src/third_party/scipy/signal_helper.py
@@ -4,7 +4,7 @@
 import scipy.signal as osp_signal
 from typing import Any, Optional, Tuple, Union
 import warnings

-from jax._src.numpy import lax_numpy as jnp
+import jax.numpy as jnp
 from jax._src.typing import Array, ArrayLike, DTypeLike

diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py
index c6658c490..6b6674caa 100644
--- a/jax/experimental/sparse/coo.py
+++ b/jax/experimental/sparse/coo.py
@@ -34,7 +34,7 @@
 from jax._src.interpreters import ad
 from jax._src.lax.lax import _const
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import gpu_sparse
-from jax._src.numpy.lax_numpy import _promote_dtypes
+from jax._src.numpy.util import _promote_dtypes
 from jax._src.typing import Array, ArrayLike, DTypeLike
 import jax.numpy as jnp
diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py
index 76fed247d..f2f10dad9 100644
--- a/jax/experimental/sparse/csr.py
+++ b/jax/experimental/sparse/csr.py
@@ -34,7 +34,7 @@
 from jax._src import dispatch
 from jax._src.interpreters import ad
 from jax._src.lax.lax import _const
 from jax._src.lib import gpu_sparse
-from jax._src.numpy.lax_numpy import _promote_dtypes
+from jax._src.numpy.util import _promote_dtypes
 from jax._src.typing import Array, ArrayLike, DTypeLike
 import jax.numpy as jnp