From 33b989ac9e52329722765e3ab3eac3dcbfcbaba8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 14 Feb 2025 11:22:18 -0800 Subject: [PATCH] refactor: import numpy objects directly in jax.numpy --- jax/_src/lax/eigh.py | 8 +- jax/_src/numpy/lax_numpy.py | 217 ++++++++++++++++-------------------- jax/_src/numpy/linalg.py | 30 ++--- jax/_src/numpy/vectorize.py | 4 +- jax/_src/ops/scatter.py | 4 +- jax/numpy/__init__.py | 32 +++--- 6 files changed, 137 insertions(+), 158 deletions(-) diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index 6cc4c1905..99711dc6b 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -61,7 +61,7 @@ def _mask(x, dims, alternative=0): Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert jnp.ndim(x) == len(dims) + assert np.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: @@ -145,7 +145,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False): N, _ = P.shape negative_column_norms = -jnp_linalg.norm(P, axis=1) # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN. - negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan) + negative_column_norms = _mask(negative_column_norms, (n,), np.nan) sort_idxs = jnp.argsort(negative_column_norms) X = P[:, sort_idxs] # X = X[:, :rank] @@ -397,7 +397,7 @@ def _eigh_work(H, n, termination_size, subset_by_index): def default_case(agenda, blocks, eigenvectors): V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # TODO: Improve this? - split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) + split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), np.nan)) H_minus, V_minus, H_plus, V_plus, rank = split_spectrum( H, b, split_point, V0=V) @@ -564,7 +564,7 @@ def eigh( eig_vals, eig_vecs = _eigh_work( H, n, termination_size=termination_size, subset_by_index=subset_by_index ) - eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan) + eig_vals = _mask(ufuncs.real(eig_vals), (n,), np.nan) if sort_eigenvalues or compute_slice: sort_idxs = jnp.argsort(eig_vals) if compute_slice: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4cd57913b..6bbef9b2a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -82,18 +82,8 @@ for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: else: break -newaxis = None T = TypeVar('T') - -# NumPy constants - -pi = np.pi -e = np.e -euler_gamma = np.euler_gamma -inf = np.inf -nan = np.nan - # Wrappers for NumPy printoptions def get_printoptions(): @@ -169,9 +159,6 @@ def iscomplexobj(x: Any) -> bool: typ = asarray(x).dtype.type return issubdtype(typ, np.complexfloating) -shape = _shape = np.shape -ndim = _ndim = np.ndim -size = np.size def _dtype(x: Any) -> DType: return dtypes.dtype(x, canonicalize=True) @@ -180,19 +167,11 @@ def _dtype(x: Any) -> DType: iinfo = dtypes.iinfo finfo = dtypes.finfo -dtype = np.dtype can_cast = dtypes.can_cast promote_types = dtypes.promote_types ComplexWarning = NumpyComplexWarning -# Numpy functions -array_str = np.array_str -array_repr = np.array_repr - -save = np.save -savez = np.savez - _lax_const = lax_internal._const @@ -534,8 +513,6 @@ def isscalar(element: Any) -> bool: return asarray(element).ndim == 0 return False -iterable = np.iterable - @export def result_type(*args: Any) -> DType: @@ -621,7 +598,7 @@ def trunc(x: ArrayLike) -> Array: @partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type']) def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, preferred_element_type: DTypeLike | None = None) -> Array: - if ndim(x) != 1 or ndim(y) != 1: + if np.ndim(x) != 1 or np.ndim(y) != 1: raise ValueError(f"{op}() only support 1-dimensional inputs.") if preferred_element_type is None: # if unspecified, promote to inexact following NumPy's default for convolutions. @@ -856,7 +833,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, util.check_arraylike("histogram_bin_edges", a, bins) arr = asarray(a) dtype = dtypes.to_inexact_dtype(arr.dtype) - if _ndim(bins) == 1: + if np.ndim(bins) == 1: return asarray(bins, dtype=dtype) bins_int = core.concrete_or_error(operator.index, bins, @@ -864,7 +841,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, if range is None: range = [arr.min(), arr.max()] range = asarray(range, dtype=dtype) - if shape(range) != (2,): + if np.shape(range) != (2,): raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}") range = (where(reductions.ptp(range) == 0, range[0] - 0.5, range[0]), where(reductions.ptp(range) == 0, range[1] + 0.5, range[1])) @@ -940,7 +917,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, weights = ones_like(a) else: util.check_arraylike("histogram", a, bins, weights) - if shape(a) != shape(weights): + if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -1105,13 +1082,13 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, sample, = util.promote_dtypes_inexact(sample) else: util.check_arraylike("histogramdd", sample, weights) - if shape(weights) != shape(sample)[:1]: + if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) - N, D = shape(sample) + N, D = np.shape(sample) if range is not None and ( - len(range) != D or any(r is not None and shape(r)[0] != 2 for r in range)): # type: ignore[arg-type] + len(range) != D or any(r is not None and np.shape(r)[0] != 2 for r in range)): # type: ignore[arg-type] raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence " f"of {D} pairs or Nones; got {range=}") @@ -1228,8 +1205,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: [2, 4]], dtype=int32) """ util.check_arraylike("transpose", a) - axes_ = list(range(ndim(a))[::-1]) if axes is None else axes - axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_] + axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes + axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_] return lax.transpose(a, axes_) @@ -1383,8 +1360,8 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: f"two, but got first argument of shape {np.shape(m)}, " f"which has ndim {np.ndim(m)}") ax1, ax2 = axes - ax1 = _canonicalize_axis(ax1, ndim(m)) - ax2 = _canonicalize_axis(ax2, ndim(m)) + ax1 = _canonicalize_axis(ax1, np.ndim(m)) + ax2 = _canonicalize_axis(ax2, np.ndim(m)) if ax1 == ax2: raise ValueError("Axes must be different") # same as numpy error k = k % 4 @@ -1393,7 +1370,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: elif k == 2: return flip(flip(m, ax1), ax2) else: - perm = list(range(ndim(m))) + perm = list(range(np.ndim(m))) perm[ax1], perm[ax2] = perm[ax2], perm[ax1] if k == 1: return transpose(flip(m, ax2), perm) @@ -1464,9 +1441,9 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: @partial(jit, static_argnames=('axis',)) def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: if axis is None: - return lax.rev(m, list(range(len(shape(m))))) + return lax.rev(m, list(range(len(np.shape(m))))) axis = _ensure_index_tuple(axis) - return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) + return lax.rev(m, [_canonicalize_axis(ax, np.ndim(m)) for ax in axis]) @export @@ -1617,7 +1594,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: im = ufuncs.imag(z) dtype = _dtype(re) if not issubdtype(dtype, np.inexact) or ( - issubdtype(_dtype(z), np.floating) and ndim(z) == 0): + issubdtype(_dtype(z), np.floating) and np.ndim(z) == 0): dtype = dtypes.canonicalize_dtype(dtypes.float_) re = lax.convert_element_type(re, dtype) im = lax.convert_element_type(im, dtype) @@ -1704,7 +1681,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, combined: list[Array] = [] if prepend is not None: prepend = util.ensure_arraylike("diff", prepend) - if not ndim(prepend): + if not np.ndim(prepend): shape = list(arr.shape) shape[axis] = 1 prepend = broadcast_to(prepend, tuple(shape)) @@ -1714,7 +1691,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, if append is not None: append = util.ensure_arraylike("diff", append) - if not ndim(append): + if not np.ndim(append): shape = list(arr.shape) shape[axis] = 1 append = broadcast_to(append, tuple(shape)) @@ -1878,12 +1855,12 @@ def gradient( upper_edge = sliced(1, 2) - sliced(0, 1) lower_edge = sliced(-1, None) - sliced(-2, -1) - if ndim(h) == 0: + if np.ndim(h) == 0: inner = (sliced(2, None) - sliced(None, -2)) * 0.5 / h lower_edge /= h upper_edge /= h - elif ndim(h) == 1: + elif np.ndim(h) == 1: if len(h) != a.shape[axis]: raise ValueError( "Spacing arrays must have the same length as the " @@ -2112,7 +2089,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: util.check_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") - return reshape(a, (size(a),), order) + return reshape(a, (np.size(a),), order) @export @@ -2259,7 +2236,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: # TODO: Consider warning here since shape is supposed to be a sequence, so # this should not happen. shape = [shape] - if any(ndim(s) != 0 for s in shape): + if any(np.ndim(s) != 0 for s in shape): raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") out_indices: list[ArrayLike] = [0] * len(shape) for i, s in reversed(list(enumerate(shape))): @@ -2385,7 +2362,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: @partial(jit, static_argnames=('axis',), inline=True) def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: if axis is None: - a_shape = shape(a) + a_shape = np.shape(a) if not core.is_constant_shape(a_shape): # We do not even know the rank of the output if the input shape is not known raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism") @@ -2507,7 +2484,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: (2, 5, 4, 3) """ util.check_arraylike("swapaxes", a) - perm = np.arange(ndim(a)) + perm = np.arange(np.ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] return lax.transpose(a, list(perm)) @@ -2567,12 +2544,12 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int], @partial(jit, static_argnames=('source', 'destination'), inline=True) def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array: - source = tuple(_canonicalize_axis(i, ndim(a)) for i in source) - destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination) + source = tuple(_canonicalize_axis(i, np.ndim(a)) for i in source) + destination = tuple(_canonicalize_axis(i, np.ndim(a)) for i in destination) if len(source) != len(destination): raise ValueError("Inconsistent number of elements: {} vs {}" .format(len(source), len(destination))) - perm = [i for i in range(ndim(a)) if i not in source] + perm = [i for i in range(np.ndim(a)) if i not in source] for dest, src in sorted(zip(destination, source)): perm.insert(dest, src) return lax.transpose(a, perm) @@ -2666,7 +2643,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: util.check_arraylike("interp", x, xp, fp) - if shape(xp) != shape(fp) or ndim(xp) != 1: + if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1: raise ValueError("xp and fp must be one-dimensional arrays of equal size") x_arr, xp_arr = util.promote_dtypes_inexact(x, xp) fp_arr, = util.promote_dtypes_inexact(fp) @@ -2691,7 +2668,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, raise ValueError("jnp.interp: complex x values not supported.") if period is not None: - if ndim(period) != 0: + if np.ndim(period) != 0: raise ValueError(f"period must be a scalar; got {period}") period = ufuncs.abs(period) x_arr = x_arr % period @@ -3018,7 +2995,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), np.integer): raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") - if ndim(x) != 1: + if np.ndim(x) != 1: raise ValueError("only 1-dimensional input supported.") minlength = core.concrete_or_error(operator.index, minlength, "The error occurred because of argument 'minlength' of jnp.bincount.") @@ -3032,7 +3009,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, "The error occurred because of argument 'length' of jnp.bincount.") if weights is None: weights = np.array(1, dtype=dtypes.int_) - elif shape(x) != shape(weights): + elif np.shape(x) != np.shape(weights): raise ValueError("shape of weights must match shape of x.") return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop') @@ -3789,7 +3766,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, """ arr = util.ensure_arraylike("nonzero", a) del a - if ndim(arr) == 0: + if np.ndim(arr) == 0: raise ValueError("Calling nonzero on 0d arrays is not allowed. " "Use jnp.atleast_1d(scalar).nonzero() instead.") mask = arr if arr.dtype == bool else (arr != 0) @@ -3805,7 +3782,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape)) if fill_value is not None: fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,) - if any(_shape(val) != () for val in fill_value_tup): + if any(np.shape(val) != () for val in fill_value_tup): raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}") fill_mask = arange(calculated_size) >= mask.sum() out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out)) @@ -3861,7 +3838,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, @export @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, - axis: int = -1, period: ArrayLike = 2 * pi) -> Array: + axis: int = -1, period: ArrayLike = 2 * np.pi) -> Array: """Unwrap a periodic signal. JAX implementation of :func:`numpy.unwrap`. @@ -3997,10 +3974,10 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str): def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array: - nd = ndim(array) + nd = np.ndim(array) constant_values = lax_internal._convert_element_type( constant_values, array.dtype, dtypes.is_weakly_typed(array)) - constant_values_nd = ndim(constant_values) + constant_values_nd = np.ndim(constant_values) if constant_values_nd == 0: widths = [(low, high, 0) for (low, high) in pad_width] @@ -4033,7 +4010,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array: - for i in range(ndim(array)): + for i in range(np.ndim(array)): if array.shape[i] == 0: _check_no_padding(pad_width[i], "wrap") continue @@ -4056,7 +4033,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int], assert mode in ("symmetric", "reflect") assert reflect_type in ("even", "odd") - for i in range(ndim(array)): + for i in range(np.ndim(array)): if array.shape[i] == 0: _check_no_padding(pad_width[i], mode) continue @@ -4121,7 +4098,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int], def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array: - nd = ndim(array) + nd = np.ndim(array) for i in range(nd): if array.shape[i] == 0: _check_no_padding(pad_width[i], "edge") @@ -4142,7 +4119,7 @@ def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array: def _pad_linear_ramp(array: Array, pad_width: PadValue[int], end_values: PadValue[ArrayLike]) -> Array: - for axis in range(ndim(array)): + for axis in range(np.ndim(array)): edge_before = lax.slice_in_dim(array, 0, 1, axis=axis) edge_after = lax.slice_in_dim(array, -1, None, axis=axis) ramp_before = linspace( @@ -4176,7 +4153,7 @@ def _pad_linear_ramp(array: Array, pad_width: PadValue[int], def _pad_stats(array: Array, pad_width: PadValue[int], stat_length: PadValue[int] | None, stat_func: PadStatFunc) -> Array: - nd = ndim(array) + nd = np.ndim(array) for i in range(nd): if stat_length is None: stat_before = stat_func(array, axis=i, keepdims=True) @@ -4215,7 +4192,7 @@ def _pad_stats(array: Array, pad_width: PadValue[int], def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array: # Note: jax.numpy.empty = jax.numpy.zeros - for i in range(ndim(array)): + for i in range(np.ndim(array)): shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:] pad_before = empty_like(array, shape=shape_before) @@ -4226,9 +4203,9 @@ def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array: def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], **kwargs) -> Array: - pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") + pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") padded = _pad_constant(array, pad_width, asarray(0)) - for axis in range(ndim(padded)): + for axis in range(np.ndim(padded)): padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs) return padded @@ -4238,7 +4215,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, constant_values: ArrayLike, stat_length: PadValueLike[int], end_values: PadValueLike[ArrayLike], reflect_type: str): array = asarray(array) - nd = ndim(array) + nd = np.ndim(array) if nd == 0: return array @@ -4406,7 +4383,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], """ util.check_arraylike("pad", array) - pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") + pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): raise TypeError('`pad_width` must be of integral type.') @@ -4501,11 +4478,11 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: util.check_arraylike("stack", *arrays) - shape0 = shape(arrays[0]) + shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] for a in arrays: - if shape(a) != shape0: + if np.shape(a) != shape0: raise ValueError("All input arrays must have the same shape.") new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) @@ -4598,7 +4575,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: reps_tup = tuple(reps) # type: ignore[arg-type] reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep for rep in reps_tup) - A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A) + A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A) reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]), [k for pair in zip(reps_tup, A_shape) for k in pair]) @@ -4667,9 +4644,9 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], raise ValueError("Need at least one array to concatenate.") if axis is None: return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype) - if ndim(arrays[0]) == 0: + if np.ndim(arrays[0]) == 0: raise ValueError("Zero-dimensional arrays cannot be concatenated.") - axis = _canonicalize_axis(axis, ndim(arrays[0])) + axis = _canonicalize_axis(axis, np.ndim(arrays[0])) if dtype is None: arrays_out = util.promote_dtypes(*arrays) else: @@ -5074,7 +5051,7 @@ def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], def _atleast_nd(x: ArrayLike, n: int) -> Array: - m = ndim(x) + m = np.ndim(x) return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x) def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: @@ -5087,7 +5064,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: xs_tup, depths = unzip2([_block(x) for x in xs]) if any(d != depths[0] for d in depths[1:]): raise ValueError("Mismatched list depths in jax.numpy.block") - rank = max(depths[0], max(ndim(x) for x in xs_tup)) + rank = max(depths[0], max(np.ndim(x) for x in xs_tup)) xs_tup = tuple(_atleast_nd(x, rank) for x in xs_tup) return concatenate(xs_tup, axis=-depths[0]), depths[0] + 1 else: @@ -5589,8 +5566,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, raise TypeError(f"Unexpected input type for array: {type(object)}") out_array: Array = lax_internal._convert_element_type( out, dtype, weak_type=weak_type, sharding=sharding) - if ndmin > ndim(out_array): - out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array))) + if ndmin > np.ndim(out_array): + out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) return out_array @@ -5839,7 +5816,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: Array(True, dtype=bool) """ a1, a2 = asarray(a1), asarray(a2) - if shape(a1) != shape(a2): + if np.shape(a1) != np.shape(a2): return array(False, dtype=bool) eq = asarray(a1 == a2) if equal_nan: @@ -6519,7 +6496,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, start = start.astype(computation_dtype) stop = stop.astype(computation_dtype) - bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop))) + bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) broadcast_start = broadcast_to(start, bounds_shape) broadcast_stop = broadcast_to(stop, bounds_shape) axis = len(bounds_shape) + axis + 1 if axis < 0 else axis @@ -6542,12 +6519,12 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, _canonicalize_axis(axis, out.ndim)) elif num == 1: - delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype) + delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) out = reshape(broadcast_start, bounds_shape) else: # num == 0 degenerate case, match numpy behavior - empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop))) + empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) empty_shape.insert(axis, 0) - delta = asarray(nan, dtype=computation_dtype) + delta = asarray(np.nan, dtype=computation_dtype) out = reshape(array([], dtype=dtype), empty_shape) if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer): @@ -7053,7 +7030,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, "value to `total_repeat_length`.") # Fast path for when repeats is a scalar. - if np.ndim(repeats) == 0 and ndim(arr) != 0: + if np.ndim(repeats) == 0 and np.ndim(arr) != 0: input_shape = arr.shape axis = _canonicalize_axis(axis, len(input_shape)) aux_axis = axis + 1 @@ -7076,7 +7053,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, # Special case when a is a scalar. if arr.ndim == 0: - if shape(repeats) == (1,): + if np.shape(repeats) == (1,): return full([total_repeat_length], arr) else: raise ValueError('`repeat` with a scalar parameter `a` is only ' @@ -7279,7 +7256,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: [7, 8]]], dtype=int32) """ util.check_arraylike("tril", m) - m_shape = shape(m) + m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") N, M = m_shape[-2:] @@ -7346,7 +7323,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: [0, 8]]], dtype=int32) """ util.check_arraylike("triu", m) - m_shape = shape(m) + m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") N, M = m_shape[-2:] @@ -7406,12 +7383,12 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") - if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)): + if _canonicalize_axis(axis1, np.ndim(a)) == _canonicalize_axis(axis2, np.ndim(a)): raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") dtypes.check_user_dtype_supported(dtype, "trace") - a_shape = shape(a) + a_shape = np.shape(a) a = moveaxis(a, (axis1, axis2), (-2, -1)) # Mask out the diagonal and reduce. @@ -7650,7 +7627,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = shape(arr) + arr_shape = np.shape(arr) if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @@ -7708,7 +7685,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = shape(arr) + arr_shape = np.shape(arr) if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) @@ -7863,12 +7840,12 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32)) """ util.check_arraylike("diag_indices_from", arr) - nd = ndim(arr) - if not ndim(arr) >= 2: + nd = np.ndim(arr) + if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") - s = shape(arr) - if len(set(shape(arr))) != 1: + s = np.shape(arr) + if len(set(np.shape(arr))) != 1: raise ValueError("All dimensions of input must be of equal length") return diag_indices(s[0], ndim=nd) @@ -7913,12 +7890,12 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, """ util.check_arraylike("diagonal", a) - if ndim(a) < 2: + if np.ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()") def _default_diag(a): - a_shape = shape(a) + a_shape = np.shape(a) a = moveaxis(a, (axis1, axis2), (-2, -1)) @@ -7932,10 +7909,10 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, # The mosaic lowering rule for diag is only defined for square arrays. # TODO(mvoz): Add support for offsets. - if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool: + if np.shape(a)[0] != np.shape(a)[1] or np.ndim(a) != 2 or offset != 0 or _dtype(a) == bool: return _default_diag(a) else: - a_shape_eye = eye(shape(a)[0], dtype=_dtype(a)) + a_shape_eye = eye(np.shape(a)[0], dtype=_dtype(a)) def _mosaic_diag(a): def _sum(x, axis): @@ -8002,7 +7979,7 @@ def diag(v: ArrayLike, k: int = 0) -> Array: @partial(jit, static_argnames=('k',)) def _diag(v, k): util.check_arraylike("diag", v) - v_shape = shape(v) + v_shape = np.shape(v) if len(v_shape) == 1: zero = lambda x: lax.full_like(x, shape=(), fill_value=0) n = v_shape[0] + abs(k) @@ -8472,7 +8449,7 @@ def apply_along_axis( Array([ 65, 133, 243], dtype=int32) """ util.check_arraylike("apply_along_axis", arr) - num_dims = ndim(arr) + num_dims = np.ndim(arr) axis = _canonicalize_axis(axis, num_dims) func = lambda arr: func1d(arr, *args, **kwargs) for i in range(1, num_dims - axis): @@ -8675,13 +8652,13 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: """ util.check_arraylike("kron", a, b) a, b = util.promote_dtypes(a, b) - if ndim(a) < ndim(b): - a = expand_dims(a, range(ndim(b) - ndim(a))) - elif ndim(b) < ndim(a): - b = expand_dims(b, range(ndim(a) - ndim(b))) - a_reshaped = expand_dims(a, range(1, 2 * ndim(a), 2)) - b_reshaped = expand_dims(b, range(0, 2 * ndim(b), 2)) - out_shape = tuple(np.multiply(shape(a), shape(b))) + if np.ndim(a) < np.ndim(b): + a = expand_dims(a, range(np.ndim(b) - np.ndim(a))) + elif np.ndim(b) < np.ndim(a): + b = expand_dims(b, range(np.ndim(a) - np.ndim(b))) + a_reshaped = expand_dims(a, range(1, 2 * np.ndim(a), 2)) + b_reshaped = expand_dims(b, range(0, 2 * np.ndim(b), 2)) + out_shape = tuple(np.multiply(np.shape(a), np.shape(b))) return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) @@ -8809,9 +8786,9 @@ def argwhere( Array([], shape=(0, 0), dtype=int32) """ result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) - if ndim(a) == 0: + if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) - return result.reshape(result.shape[0], ndim(a)) + return result.reshape(result.shape[0], np.ndim(a)) @export @@ -8859,7 +8836,7 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None, @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: if axis is None: - dims = list(range(ndim(a))) + dims = list(range(np.ndim(a))) a = ravel(a) axis = 0 else: @@ -8915,7 +8892,7 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None, @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: if axis is None: - dims = list(range(ndim(a))) + dims = list(range(np.ndim(a))) a = ravel(a) axis = 0 else: @@ -8989,7 +8966,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): if not issubdtype(_dtype(a), np.inexact): return argmax(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) - a = where(nan_mask, -inf, a) + a = where(nan_mask, -np.inf, a) res = argmax(a, axis=axis, keepdims=keepdims) return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) @@ -9050,7 +9027,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): if not issubdtype(_dtype(a), np.inexact): return argmin(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) - a = where(nan_mask, inf, a) + a = where(nan_mask, np.inf, a) res = argmin(a, axis=axis, keepdims=keepdims) return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) @@ -9191,7 +9168,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """ util.check_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") - a_ndim = ndim(a) + a_ndim = np.ndim(a) axis = _canonicalize_axis(axis, a_ndim) if not (-a_ndim <= start <= a_ndim): raise ValueError(f"{start=} must satisfy {-a_ndim}<=start<={a_ndim}") @@ -9764,9 +9741,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, w: Array | None = None if fweights is not None: fweights = util.ensure_arraylike("cov", fweights) - if ndim(fweights) > 1: + if np.ndim(fweights) > 1: raise RuntimeError("cannot handle multidimensional fweights") - if shape(fweights)[0] != X.shape[1]: + if np.shape(fweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and fweights") if not issubdtype(_dtype(fweights), np.integer): raise TypeError("fweights must be integer.") @@ -9774,9 +9751,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, w = abs(fweights) if aweights is not None: aweights = util.ensure_arraylike("cov", aweights) - if ndim(aweights) > 1: + if np.ndim(aweights) > 1: raise RuntimeError("cannot handle multidimensional aweights") - if shape(aweights)[0] != X.shape[1]: + if np.shape(aweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and aweights") # Ensure positive aweights: note that numpy raises an error for negative aweights. aweights = abs(aweights) @@ -9877,7 +9854,7 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A """ util.check_arraylike("corrcoef", x) c = cov(x, y, rowvar) - if len(shape(c)) == 0: + if len(np.shape(c)) == 0: # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise return ufuncs.divide(c, c) d = diag(c) @@ -10002,7 +9979,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', raise ValueError( f"{method!r} is an invalid value for keyword 'method'. " "Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].") - if ndim(a) != 1: + if np.ndim(a) != 1: raise ValueError("a should be 1-dimensional") a, v = util.promote_dtypes(a, v) if sorter is not None: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 7429af845..e0b1442d5 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -478,7 +478,7 @@ def _slogdet_lu(a: Array) -> tuple[Array, Array]: jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where( - is_zero, jnp.array(-jnp.inf, dtype=dtype), + is_zero, jnp.array(-np.inf, dtype=dtype), reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1)) return sign, ufuncs.real(logdet) @@ -539,7 +539,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ a = ensure_arraylike("jnp.linalg.slogdet", a) a, = promote_dtypes_inexact(a) - a_shape = jnp.shape(a) + a_shape = np.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": @@ -610,8 +610,8 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]: a, b = ensure_arraylike("jnp.linalg._cofactor_solve", a, b) a, = promote_dtypes_inexact(a) b, = promote_dtypes_inexact(b) - a_shape = jnp.shape(a) - b_shape = jnp.shape(b) + a_shape = np.shape(a) + b_shape = np.shape(b) a_ndims = len(a_shape) if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] and b_shape[-2:] == a_shape[-2:]): @@ -710,7 +710,7 @@ def det(a: ArrayLike) -> Array: """ a = ensure_arraylike("jnp.linalg.det", a) a, = promote_dtypes_inexact(a) - a_shape = jnp.shape(a) + a_shape = np.shape(a) if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2: return _det_2x2(a) elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3: @@ -976,10 +976,10 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian) # Singular values less than or equal to ``rtol * largest_singular_value`` # are set to zero. - rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1)) + rtol = lax.expand_dims(rtol[..., np.newaxis], range(s.ndim - rtol.ndim - 1)) cutoff = rtol * s[..., 0:1] - s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) - res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]), + s = jnp.where(s > cutoff, s, np.inf).astype(u.dtype) + res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., np.newaxis]), precision=lax.Precision.HIGHEST) return lax.convert_element_type(res, arr.dtype) @@ -1148,7 +1148,7 @@ def norm(x: ArrayLike, ord: int | str | None = None, """ x = ensure_arraylike("jnp.linalg.norm", x) x, = promote_dtypes_inexact(x) - x_shape = jnp.shape(x) + x_shape = np.shape(x) ndim = len(x_shape) if axis is None: @@ -1181,12 +1181,12 @@ def norm(x: ArrayLike, ord: int | str | None = None, col_axis -= 1 return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims), axis=col_axis, keepdims=keepdims) - elif ord == jnp.inf: + elif ord == np.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims), axis=row_axis, keepdims=keepdims) - elif ord == -jnp.inf: + elif ord == -np.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims), @@ -1392,7 +1392,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] rank = mask.sum() safe_s = jnp.where(mask, s, 1).astype(a.dtype) - s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] + s_inv = jnp.where(mask, 1 / safe_s, 0)[:, np.newaxis] uTb = tensor_contractions.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) x = tensor_contractions.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) # Numpy returns empty residuals in some cases. To allow compilation, we @@ -1651,9 +1651,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k if ord is None or ord == 2: return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, keepdims=keepdims)) - elif ord == jnp.inf: + elif ord == np.inf: return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == -jnp.inf: + elif ord == -np.inf: return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif ord == 0: return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, @@ -2177,7 +2177,7 @@ def cond(x: ArrayLike, p=None): raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}") r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1)) # Convert NaNs to infs where original array has no NaNs. - return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) + return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), np.inf, r) @export diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index f1e6d399b..e6ad1386a 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -19,6 +19,8 @@ import re from typing import Any import warnings +import numpy as np + from jax._src import api from jax._src import config from jax import lax @@ -140,7 +142,7 @@ def _check_output_dims( """Check that output core dimensions match the signature.""" def wrapped(*args): out = func(*args) - out_shapes = map(jnp.shape, out if isinstance(out, tuple) else [out]) + out_shapes = map(np.shape, out if isinstance(out, tuple) else [out]) if expected_output_core_dims is None: output_core_dims = [()] * len(out_shapes) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index ed9a9eb02..e19be6622 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -98,7 +98,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, FutureWarning) idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = indexing.index_to_gather(jnp.shape(x), idx, + indexer = indexing.index_to_gather(np.shape(x), idx, normalize_indices=normalize_indices) # Avoid calling scatter if the slice shape is empty, both as a fast path and @@ -110,7 +110,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, # Broadcast `y` to the slice output shape. y = jnp.broadcast_to(y, tuple(indexer.slice_shape)) - # Collapse any `None`/`jnp.newaxis` dimensions. + # Collapse any `None`/`np.newaxis` dimensions. y = jnp.squeeze(y, axis=indexer.newaxis_dims) if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index c1429eae5..d563483a2 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -39,9 +39,7 @@ from jax._src.numpy.lax_numpy import ( array as array, array_equal as array_equal, array_equiv as array_equiv, - array_repr as array_repr, array_split as array_split, - array_str as array_str, astype as astype, asarray as asarray, atleast_1d as atleast_1d, @@ -75,10 +73,7 @@ from jax._src.numpy.lax_numpy import ( digitize as digitize, dsplit as dsplit, dstack as dstack, - dtype as dtype, - e as e, ediff1d as ediff1d, - euler_gamma as euler_gamma, expand_dims as expand_dims, extract as extract, eye as eye, @@ -111,7 +106,6 @@ from jax._src.numpy.lax_numpy import ( identity as identity, iinfo as iinfo, indices as indices, - inf as inf, insert as insert, interp as interp, isclose as isclose, @@ -121,7 +115,6 @@ from jax._src.numpy.lax_numpy import ( isrealobj as isrealobj, isscalar as isscalar, issubdtype as issubdtype, - iterable as iterable, ix_ as ix_, kron as kron, lcm as lcm, @@ -132,17 +125,13 @@ from jax._src.numpy.lax_numpy import ( matrix_transpose as matrix_transpose, meshgrid as meshgrid, moveaxis as moveaxis, - nan as nan, nan_to_num as nan_to_num, nanargmax as nanargmax, nanargmin as nanargmin, - ndim as ndim, - newaxis as newaxis, nonzero as nonzero, packbits as packbits, pad as pad, permute_dims as permute_dims, - pi as pi, piecewise as piecewise, printoptions as printoptions, promote_types as promote_types, @@ -156,13 +145,9 @@ from jax._src.numpy.lax_numpy import ( rollaxis as rollaxis, rot90 as rot90, round as round, - save as save, - savez as savez, searchsorted as searchsorted, select as select, set_printoptions as set_printoptions, - shape as shape, - size as size, split as split, squeeze as squeeze, stack as stack, @@ -277,17 +262,32 @@ from jax._src.numpy.window_functions import ( kaiser as kaiser, ) -# NumPy generic scalar types: +# Some APIs come directly from NumPy: from numpy import ( + array_repr as array_repr, + array_str as array_str, character as character, complexfloating as complexfloating, + dtype as dtype, + e as e, + euler_gamma as euler_gamma, flexible as flexible, floating as floating, generic as generic, inexact as inexact, + inf as inf, integer as integer, + iterable as iterable, + nan as nan, + ndim as ndim, + newaxis as newaxis, number as number, object_ as object_, + pi as pi, + save as save, + savez as savez, + shape as shape, + size as size, signedinteger as signedinteger, unsignedinteger as unsignedinteger, )