From f652b6ad6aa44e586ee8989a39ca95b63205cec3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 06:03:54 -0800 Subject: [PATCH] Set __module__ attribute for objects in jax.numpy --- jax/_src/dtypes.py | 3 + jax/_src/numpy/index_tricks.py | 12 +- jax/_src/numpy/lax_numpy.py | 188 +++++++++++++++++++++++++++++++- jax/_src/numpy/polynomial.py | 14 +++ jax/_src/numpy/reductions.py | 40 ++++++- jax/_src/numpy/setops.py | 14 ++- jax/_src/numpy/ufunc_api.py | 5 +- jax/_src/numpy/ufuncs.py | 114 +++++++++++++++++++ jax/_src/numpy/vectorize.py | 5 +- tests/package_structure_test.py | 11 +- 10 files changed, 396 insertions(+), 10 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f5b0c3fd6..1c5e285ba 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -343,6 +343,7 @@ _types_for_issubdtype = (type, np.dtype, ExtendedDType) # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). +@set_module('jax.numpy') def issubdtype(a: DTypeLike | ExtendedDType | None, b: DTypeLike | ExtendedDType | None) -> bool: """Returns True if first argument is a typecode lower/equal in type hierarchy. @@ -458,6 +459,7 @@ _dtype_kinds: dict[str, set] = { } +@set_module('jax.numpy') def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool: """Returns a boolean indicating whether a provided dtype is of a specified kind. @@ -650,6 +652,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "JAX's internal logic; please report it to the JAX maintainers." ) +@set_module('jax.numpy') def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 90a17000c..ec67d7489 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -24,10 +24,14 @@ from jax._src.numpy.lax_numpy import ( arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module import numpy as np +export = set_module('jax.numpy') + + __all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] @@ -87,7 +91,7 @@ class _Mgrid: return stack(output_arr, 0) -mgrid = _Mgrid() +mgrid = export(_Mgrid()) class _Ogrid: @@ -129,7 +133,7 @@ class _Ogrid: return meshgrid(*output, indexing='ij', sparse=True) -ogrid = _Ogrid() +ogrid = export(_Ogrid()) _IndexType = Union[ArrayLike, str, slice] @@ -279,7 +283,7 @@ class RClass(_AxisConcat): op_name = "r_" -r_ = RClass() +r_ = export(RClass()) class CClass(_AxisConcat): @@ -327,7 +331,7 @@ class CClass(_AxisConcat): op_name = "c_" -c_ = CClass() +c_ = export(CClass()) s_ = np.s_ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4cf37f6f7..4c261d111 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,13 +68,16 @@ from jax._src.typing import ( ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum +export = set_module('jax.numpy') + for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: try: cuda_plugin_extension = importlib.import_module( @@ -116,6 +119,7 @@ get_printoptions = np.get_printoptions printoptions = np.printoptions set_printoptions = np.set_printoptions +@export def iscomplexobj(x: Any) -> bool: """Check if the input is a complex number or an array containing complex elements. @@ -327,6 +331,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) +@export def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: """Load JAX arrays from npy files. @@ -376,6 +381,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> ### implementations of numpy functions in terms of lax +@export @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. @@ -427,6 +433,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise maximum of the input arrays. @@ -476,6 +483,7 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: """Return True if arg1 is equal or lower than arg2 in the type hierarchy. @@ -522,6 +530,7 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) +@export def isscalar(element: Any) -> bool: """Return True if the input is a scalar. @@ -620,6 +629,7 @@ def isscalar(element: Any) -> bool: iterable = np.iterable +@export def result_type(*args: Any) -> DType: """Return the result of applying JAX promotion rules to the inputs. @@ -663,6 +673,7 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) +@export @jit def trunc(x: ArrayLike) -> Array: """Round input to the nearest integer towards zero. @@ -739,6 +750,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, @@ -814,6 +826,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision=precision, preferred_element_type=preferred_element_type) +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, @@ -899,6 +912,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision=precision, preferred_element_type=preferred_element_type) +@export def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: None | Array | Sequence[ArrayLike] = None, weights: ArrayLike | None = None) -> Array: @@ -950,6 +964,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, return linspace(range[0], range[1], bins_int + 1, dtype=dtype) +@export def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Sequence[ArrayLike] | None = None, weights: ArrayLike | None = None, @@ -1031,6 +1046,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, return counts, bin_edges +@export def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1120,6 +1136,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = return hist, edges[0], edges[1] +@export def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1229,6 +1246,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim +@export def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -1307,6 +1325,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) +@export def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: """Permute the axes/dimensions of an array. @@ -1336,6 +1355,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: return lax.transpose(a, axes) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose the last two dimensions of an array. @@ -1389,6 +1409,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) +@export @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. @@ -1472,6 +1493,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) +@export def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Reverse the order of elements of an array along the given axis. @@ -1539,6 +1561,7 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) +@export def fliplr(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 1. @@ -1565,6 +1588,7 @@ def fliplr(m: ArrayLike) -> Array: return _flip(asarray(m), 1) +@export def flipud(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 0. @@ -1590,6 +1614,8 @@ def flipud(m: ArrayLike) -> Array: util.check_arraylike("flipud", m) return _flip(asarray(m), 0) + +@export @jit def iscomplex(x: ArrayLike) -> Array: """Return boolean array showing where the input is complex. @@ -1613,6 +1639,8 @@ def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) + +@export @jit def isreal(x: ArrayLike) -> Array: """Return boolean array showing where the input is real. @@ -1637,6 +1665,7 @@ def isreal(x: ArrayLike) -> Array: return lax.eq(i, _lax_const(i, 0)) +@export @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: """Return the angle of a complex valued number or array. @@ -1688,6 +1717,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result +@export @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, @@ -1800,6 +1830,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr +@export @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: @@ -1862,6 +1893,8 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result + +@export @partial(jit, static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, @@ -1992,6 +2025,7 @@ def gradient( return a_grad[0] if len(axis_tuple) == 1 else a_grad +@export def isrealobj(x: Any) -> bool: """Check if the input is not a complex number or an array containing complex elements. @@ -2026,6 +2060,7 @@ def isrealobj(x: Any) -> bool: return not iscomplexobj(x) +@export def reshape( a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), @@ -2129,6 +2164,7 @@ def reshape( return asarray(a).reshape(shape, order=order) +@export @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: """Flatten array into a 1-dimensional shape. @@ -2182,6 +2218,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: return reshape(a, (size(a),), order) +@export def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: """Convert multi-dimensional indices into flat indices. @@ -2273,6 +2310,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], return result +@export def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: """Convert flat indices into multi-dimensional indices. @@ -2336,6 +2374,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: for s, i in safe_zip(shape, out_indices)) +@export @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: """Return a new array with specified shape. @@ -2387,6 +2426,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) +@export def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Remove one or more length-1 axes from array @@ -2457,6 +2497,7 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: return lax.squeeze(a, axis) +@export def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: """Insert dimensions of length 1 into array @@ -2527,6 +2568,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: return lax.expand_dims(a, axis) +@export @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: """Swap two axes of an array. @@ -2574,6 +2616,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: return lax.transpose(a, list(perm)) +@export def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: """Move an array axis to a new position @@ -2639,6 +2682,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - return lax.transpose(a, perm) +@export @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -2783,6 +2827,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f +@export def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, @@ -2865,6 +2910,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, ) -> Array | tuple[Array, ...]: ... +@export def where(condition, x=None, y=None, /, *, size=None, fill_value=None): """Select elements from two arrays based on a condition. @@ -2940,6 +2986,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) +@export def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -3007,6 +3054,7 @@ def select( return lax.select_n(*broadcast_arrays(idx, *choicelist)) +@export def bincount(x: ArrayLike, weights: ArrayLike | None = None, minlength: int = 0, *, length: int | None = None ) -> Array: @@ -3099,6 +3147,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... +@export def broadcast_shapes(*shapes): """Broadcast input shapes to a common output shape. @@ -3139,6 +3188,7 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) +@export def broadcast_arrays(*args: ArrayLike) -> list[Array]: """Broadcast arrays to a common shape. @@ -3178,6 +3228,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: return util._broadcast_arrays(*args) +@export def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: """Broadcast an array to a specified shape. @@ -3254,6 +3305,7 @@ def _split(op: str, ary: ArrayLike, for start, end in zip(split_indices[:-1], split_indices[1:])] +@export def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3317,6 +3369,7 @@ def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, return _split("split", ary, indices_or_sections, axis=axis) +@export def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays vertically. @@ -3351,6 +3404,7 @@ def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("vsplit", ary, indices_or_sections, axis=0) +@export def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays horizontally. @@ -3391,6 +3445,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) +@export def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays depth-wise. @@ -3432,6 +3487,7 @@ def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("dsplit", ary, indices_or_sections, axis=2) +@export def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3457,6 +3513,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array return _split("array_split", ary, indices_or_sections, axis=axis) +@export @jit def clip( arr: ArrayLike | None = None, @@ -3528,6 +3585,7 @@ def clip( return asarray(arr) +@export @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. @@ -3599,12 +3657,14 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: return _round_float(a) +@export @partial(jit, static_argnames=('decimals',)) def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Alias of :func:`jax.numpy.round`""" return round(a, decimals, out) +@export @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. @@ -3643,6 +3703,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) +@export @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, @@ -3708,6 +3769,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, return out +@export @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -3756,6 +3818,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) +@export def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: @@ -3863,6 +3926,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return out +@export def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: """Return indices of nonzero elements in a flattened array @@ -3908,6 +3972,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, return nonzero(ravel(a), size=size, fill_value=fill_value)[0] +@export @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: @@ -4337,6 +4402,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") +@export def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: """Add padding to an array. @@ -4493,6 +4559,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], ### Array-creation functions +@export def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: """Join arrays along a new axis. @@ -4559,6 +4626,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], return concatenate(new_arrays, axis=axis, dtype=dtype) +@export @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: """Unstack an array along an axis. @@ -4599,6 +4667,8 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: ) return tuple(moveaxis(x, axis, 0)) + +@export def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: """Construct an array by repeating ``A`` along specified dimensions. @@ -4662,6 +4732,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, return lax.reshape(arr, shape, dimensions) +@export def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: """Join arrays along an existing axis. @@ -4725,6 +4796,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] +@export def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: """Join arrays along an existing axis. @@ -4765,6 +4837,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: return jax.numpy.concatenate(arrays, axis=axis) +@export def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Vertically stack arrays. @@ -4825,6 +4898,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) +@export def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Horizontally stack arrays. @@ -4885,6 +4959,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) +@export def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Stack arrays depth-wise. @@ -4945,6 +5020,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) +@export def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """Stack arrays column-wise. @@ -5005,6 +5081,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) +@export def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: """Construct an array by stacking slices of choice arrays. @@ -5129,6 +5206,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: return asarray(xs), 1 +@export @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: """Create an array from a list of blocks. @@ -5212,6 +5290,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 1 dimension. @@ -5266,6 +5345,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 2 dimensions. @@ -5329,6 +5409,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 3 dimensions. @@ -5405,6 +5486,7 @@ def _supports_buffer_protocol(obj): return True +@export def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5597,6 +5679,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x +@export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: @@ -5662,6 +5745,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result +@export def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: @@ -5743,6 +5827,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) +@export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. @@ -5791,6 +5876,7 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) +@export def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5833,6 +5919,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5875,6 +5962,7 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5924,6 +6012,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device +@export def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5972,6 +6061,7 @@ def full(shape: Any, fill_value: ArrayLike, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -6028,6 +6118,7 @@ def full_like(a: ArrayLike | DuckTypedArray, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of zeros. @@ -6064,6 +6155,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of ones. @@ -6100,6 +6192,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an empty array. @@ -6143,6 +6236,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") +@export def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: """Check if two arrays are element-wise equal. @@ -6184,6 +6278,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return reductions.all(eq) +@export def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: """Check if two arrays are element-wise equal. @@ -6224,6 +6319,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. +@export def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: r"""Convert a buffer into a 1-D JAX array. @@ -6271,6 +6367,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) +@export def fromfile(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromfile. @@ -6289,6 +6386,7 @@ def fromfile(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def fromiter(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromiter. @@ -6307,6 +6405,7 @@ def fromiter(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array: """Construct a JAX array via DLPack. @@ -6367,6 +6466,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, return from_dlpack(x, device=device, copy=copy) +@export def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: """Create an array from a function applied over indices. @@ -6453,6 +6553,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) +@export def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: """Convert a string of text into 1-D JAX array. @@ -6481,6 +6582,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) +@export def eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None, @@ -6560,6 +6662,7 @@ def _eye(N: DimSize, M: DimSize | None = None, return (i + offset == j).astype(dtype) +@export def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: """Create a square identity matrix @@ -6593,6 +6696,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: return eye(n, dtype=dtype) +@export def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -6760,6 +6864,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, @@ -6885,6 +6990,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result +@export def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: @@ -6970,6 +7076,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) +@export def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: """Generate geometrically-spaced values. @@ -7044,6 +7151,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) +@export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: """Construct N-dimensional grid arrays from N 1-dimensional vectors. @@ -7125,6 +7233,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output +@export @jit def i0(x: ArrayLike) -> Array: r"""Calculate modified Bessel function of first kind, zeroth order. @@ -7174,6 +7283,7 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) +@export def ix_(*args: ArrayLike) -> tuple[Array, ...]: """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. @@ -7237,6 +7347,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... +@export def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: """Generate arrays of grid indices. @@ -7287,6 +7398,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, return stack(output, 0) if output else array([], dtype=dtype) +@export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: """Construct an array from repeated elements. @@ -7431,6 +7543,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@export @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: @@ -7490,6 +7603,7 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) +@export def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: r"""Return an array with ones on and below the diagonal and zeros elsewhere. @@ -7546,6 +7660,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None return lax_internal._tri(dtype, (N, M), k) +@export @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: r"""Return lower triangle of an array. @@ -7607,6 +7722,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) +@export @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: r"""Return upper triangle of an array. @@ -7672,6 +7788,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) +@export @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -7737,6 +7854,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int return reductions.sum(a, axis=(-2, -1), dtype=dtype) +@export def mask_indices(n: int, mask_func: Callable[[ArrayLike, int], Array], k: int = 0, *, size: int | None = None) -> tuple[Array, Array]: @@ -7796,6 +7914,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) +@export def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of upper triangle of an array of size ``(n, m)``. @@ -7854,6 +7973,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of lower triangle of an array of size ``(n, m)``. @@ -7912,6 +8032,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. @@ -7969,6 +8090,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. @@ -8026,6 +8148,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: """Return a copy of the array with the diagonal overwritten. @@ -8107,6 +8230,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) +@export def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a multidimensional array. @@ -8142,6 +8266,8 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: .format(ndim)) return (lax.iota(int_, n),) * ndim + +@export def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a given array. @@ -8183,6 +8309,8 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) + +@export @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: @@ -8234,6 +8362,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] +@export def diag(v: ArrayLike, k: int = 0) -> Array: """Returns the specified diagonal or constructs a diagonal array. @@ -8297,6 +8426,8 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") + +@export def diagflat(v: ArrayLike, k: int = 0) -> Array: """Return a 2-D array with the flattened input array laid out on the diagonal. @@ -8353,6 +8484,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: # TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 +@export def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. @@ -8407,6 +8539,8 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] + +@export @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None @@ -8461,6 +8595,7 @@ def append( return concatenate([arr, values], axis=axis) +@export def delete( arr: ArrayLike, obj: ArrayLike | slice, @@ -8585,6 +8720,7 @@ def delete( return a[tuple(slice(None) for i in range(axis)) + (mask,)] +@export def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = None) -> Array: """Insert entries into an array at specified indices. @@ -8684,6 +8820,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, return out +@export def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: @@ -8761,6 +8898,7 @@ def apply_along_axis( return func(arr) +@export def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: """Apply a function repeatedly over specified axes. @@ -8819,6 +8957,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, ### Tensor contraction operations +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -8908,6 +9047,7 @@ def dot(a: ArrayLike, b: ArrayLike, *, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9031,6 +9171,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, @@ -9079,6 +9220,7 @@ def vdot( preferred_element_type=preferred_element_type) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -9134,6 +9276,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) +@export def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, @@ -9279,6 +9422,7 @@ def einsum( out_type=None, ) -> Array: ... +@export def einsum( subscripts, /, *operands, @@ -9554,6 +9698,7 @@ def einsum_path( optimize: bool | str | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... +@export def einsum_path( subscripts, /, *operands, @@ -9787,6 +9932,7 @@ def _einsum( output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9843,6 +9989,7 @@ def inner( preferred_element_type=preferred_element_type) +@export @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """Compute the outer product of two arrays. @@ -9877,6 +10024,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: return ravel(a)[:, None] * ravel(b)[None, :] +@export @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): @@ -9977,6 +10125,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) +@export @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: """Compute the Kronecker product of two input arrays. @@ -10022,6 +10171,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) +@export @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False @@ -10085,6 +10235,7 @@ def vander( ### Misc +@export def argwhere( a: ArrayLike, *, @@ -10150,6 +10301,7 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) +@export def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the maximum value of an array. @@ -10205,6 +10357,7 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the minimum value of an array. @@ -10260,6 +10413,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def nanargmax( a: ArrayLike, axis: int | None = None, @@ -10327,6 +10481,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export def nanargmin( a: ArrayLike, axis: int | None = None, @@ -10387,6 +10542,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, @@ -10450,6 +10606,7 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result +@export @jit def sort_complex(a: ArrayLike) -> Array: """Return a sorted copy of complex array. @@ -10487,6 +10644,7 @@ def sort_complex(a: ArrayLike) -> Array: return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) +@export @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: """Sort a sequence of keys in lexicographic order. @@ -10564,6 +10722,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, @@ -10644,6 +10803,7 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices +@export @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns a partially-sorted copy of an array. @@ -10714,6 +10874,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) +@export @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns indices that partially sort an array. @@ -10818,6 +10979,8 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a + +@export def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: """Roll the elements of an array along a specified axis. @@ -10871,6 +11034,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) +@export @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """Roll the specified axis to a given position. @@ -10936,6 +11100,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) +@export @partial(jit, static_argnames=('axis', 'bitorder')) def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: """Pack array of bits into a uint8 array. @@ -11020,6 +11185,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar return swapaxes(packed, axis, -1) +@export @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -11111,6 +11277,7 @@ def unpackbits( return swapaxes(unpacked, axis, -1) +@export def take( a: ArrayLike, indices: ArrayLike, @@ -11268,6 +11435,7 @@ def _normalize_index(index, axis_size): return lax.select(index < 0, lax.add(index, axis_size_val), index) +@export @partial(jit, static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, @@ -11462,6 +11630,7 @@ def _make_along_axis_idx(shape, indices, axis): return tuple_replace(_indices(shape, sparse=True), axis, indices) +@export @partial(jit, static_argnames=('axis', 'inplace', 'mode')) def put_along_axis( arr: ArrayLike, @@ -12206,6 +12375,7 @@ def _preprocess_slice( return start, step, slice_size +@export def blackman(M: int) -> Array: """Return a Blackman window of size M. @@ -12236,6 +12406,7 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) +@export def bartlett(M: int) -> Array: """Return a Bartlett window of size M. @@ -12266,6 +12437,7 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) +@export def hamming(M: int) -> Array: """Return a Hamming window of size M. @@ -12296,6 +12468,7 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) +@export def hanning(M: int) -> Array: """Return a Hanning window of size M. @@ -12326,6 +12499,7 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) +@export def kaiser(M: int, beta: ArrayLike) -> Array: """Return a Kaiser window of size M. @@ -12368,6 +12542,8 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) + +@export @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the greatest common divisor of two arrays. @@ -12414,6 +12590,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd +@export @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the least common multiple of two arrays. @@ -12461,6 +12638,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) +@export def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: """Return the elements of an array that satisfy a condition. @@ -12522,6 +12700,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value) +@export def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, *, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array: """Compress an array along a given axis using a boolean condition. @@ -12616,6 +12795,7 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(result, 0, axis) +@export @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, @@ -12774,6 +12954,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() +@export @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: r"""Compute the Pearson correlation coefficients. @@ -12903,6 +13084,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) +@export @partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: @@ -12992,6 +13174,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', return impl(asarray(a), asarray(v), side, dtype) # type: ignore +@export @partial(jit, static_argnames=('right', 'method')) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str | None = None) -> Array: @@ -13047,6 +13230,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, ) +@export def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: @@ -13154,6 +13338,7 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr +@export def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: """Update array elements based on a mask. @@ -13229,6 +13414,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) +@export def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = None, *, inplace: bool = True) -> Array: """Put elements into an array at given indices. diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 10cc90575..19388b903 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -33,6 +33,10 @@ from jax._src.numpy import linalg from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, _where) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module + + +export = set_module('jax.numpy') @jit @@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) +@export def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: r"""Returns the roots of a polynomial given the coefficients ``p``. @@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: return _roots_with_zeros(p_arr, num_leading_zeros) +@export @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, full: bool = False, w: ArrayLike | None = None, cov: bool = False @@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, return c +@export @jit def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. @@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: return a +@export @partial(jit, static_argnames=['unroll']) def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. @@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: return y +@export @jit def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the sum of the two polynomials. @@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) +@export @partial(jit, static_argnames=('m',)) def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. @@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array return true_divide(concatenate((p_arr, k_arr)), coeff) +@export @partial(jit, static_argnames=('m',)) def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. @@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array: return p_arr[:-m] * coeff[::-1] +@export def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: r"""Returns the product of two polynomials. @@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - return convolve(a1_arr, a2_arr, mode='full') +@export def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: r"""Returns the quotient and remainder of polynomial division. @@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> return q, u_arr +@export @jit def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the difference of two polynomials. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 5acad86ea..bc85bc3e8 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -37,9 +37,11 @@ from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, - NumpyComplexWarning) + set_module, NumpyComplexWarning) +export = set_module('jax.numpy') + _all = builtins.all _lax_const = lax_internal._const @@ -222,6 +224,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) +@export def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: @@ -296,6 +299,7 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) + @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -307,6 +311,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None initial=initial, where_=where, promote_integers=promote_integers) +@export def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -391,6 +396,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmax) +@export def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -473,6 +479,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmin) +@export def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -552,6 +559,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether all array elements along a given axis evaluate to True. @@ -608,6 +616,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether any of the array elements along a given axis evaluate to True. @@ -714,6 +723,7 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) +@export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -721,6 +731,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None, return min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) +@export def amax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -740,6 +751,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): return size +@export def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -843,6 +855,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... +@export def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: """Compute the weighed average. @@ -953,6 +966,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg +@export def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1093,6 +1107,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) +@export def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1185,6 +1200,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) +@export def ptp(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: r"""Return the peak-to-peak range along a given axis. @@ -1236,6 +1252,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, return lax.sub(x, y) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: @@ -1295,6 +1312,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], return out +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1377,6 +1395,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1459,6 +1478,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1542,6 +1562,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1625,6 +1646,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: @@ -1716,6 +1738,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out return td +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1818,6 +1841,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: return lax.convert_element_type(result, dtype) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1939,6 +1963,7 @@ def _cumulative_reduction( return result +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1975,6 +2000,7 @@ def cumsum(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2010,6 +2036,7 @@ def cumprod(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2059,6 +2086,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None, fill_nan=True, fill_value=0) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2115,6 +2143,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, a, axis, dtype, out, promote_integers=True) +@export def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2176,6 +2205,7 @@ def cumulative_sum( return out +@export def cumulative_prod( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2239,6 +2269,7 @@ def cumulative_prod( # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2295,6 +2326,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2475,7 +2507,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2531,7 +2565,9 @@ def percentile(a: ArrayLike, q: ArrayLike, return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2591,6 +2627,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, method=method, keepdims=keepdims) +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, @@ -2642,6 +2679,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, keepdims=keepdims, method='midpoint') +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 6491a7617..0d5ea905b 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -35,10 +35,12 @@ from jax._src.numpy.lax_numpy import ( from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import check_arraylike, promote_dtypes -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike +export = set_module('jax.numpy') + _lax_const = lax_internal._const @@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: return arr, num_unique1 + num_unique2 +@export def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) +@export def union1d(ar1: ArrayLike, ar2: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set union of two 1D arrays. @@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, return where(arange(len(vals)) < num_results, vals, fill_value) +@export def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. @@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as return vals +@export def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: @@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d +@export def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = False, invert: bool = False, *, method='auto') -> Array: @@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo return ret[0] if len(ret) == 1 else ret +@export def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, axis: int | None = None, *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): @@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple): inverse_indices: Array +@export def unique_all(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueAllResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) +@export def unique_counts(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueCountsResult: """Return unique values from x, along with counts. @@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, return _UniqueCountsResult(values=values, counts=counts) +@export def unique_inverse(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueInverseResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) +@export def unique_values(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Return unique values from x, along with indices, inverse indices, and counts. diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 27e2973b2..5dbd67e62 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -33,6 +33,8 @@ from jax._src.util import canonicalize_axis, set_module import numpy as np +export = set_module("jax.numpy") + _AT_INPLACE_WARNING = """\ Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. @@ -40,7 +42,7 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. """ -@set_module('jax.numpy') +@export class ufunc: """Universal functions which operation element-by-element on arrays. @@ -586,6 +588,7 @@ class ufunc: return result.reshape(*np.shape(A), *np.shape(B)) +@export def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, *, identity: Any = None) -> ufunc: """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a844ecbc2..bbbce9733 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -38,6 +38,10 @@ from jax._src.numpy.util import ( promote_shapes, _where, check_no_float0s) from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy import reductions +from jax._src.util import set_module + + +export = set_module('jax.numpy') _lax_const = lax._const @@ -75,6 +79,7 @@ def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None, return decorator +@export @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -119,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array: return lax.abs(*promote_args_inexact('fabs', x)) +@export @partial(jit, inline=True) def bitwise_invert(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_invert', x)) +@export @partial(jit, inline=True) def bitwise_not(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_not', x)) +@export @partial(jit, inline=True) def invert(x: ArrayLike, /) -> Array: """Compute the bitwise inversion of an input. @@ -223,6 +231,7 @@ def negative(x: ArrayLike, /) -> Array: return lax.neg(*promote_args('negative', x)) +@export @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: """Return element-wise positive values of the input. @@ -271,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array: return lax.asarray(*promote_args('positive', x)) +@export @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. @@ -321,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) +@export @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. @@ -359,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array: return lax.floor(*promote_args_inexact('floor', x)) +@export @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. @@ -397,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array: return lax.ceil(*promote_args_inexact('ceil', x)) +@export @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. @@ -438,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array: return lax.exp(*promote_args_inexact('exp', x)) +@export @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. @@ -475,6 +489,7 @@ def log(x: ArrayLike, /) -> Array: return lax.log(*promote_args_inexact('log', x)) +@export @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: """Calculate ``exp(x)-1`` of each element of the input. @@ -519,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array: return lax.expm1(*promote_args_inexact('expm1', x)) +@export @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: """Calculates element-wise logarithm of one plus input, ``log(x+1)``. @@ -559,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array: return lax.log1p(*promote_args_inexact('log1p', x)) +@export @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. @@ -590,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array: return lax.sin(*promote_args_inexact('sin', x)) +@export @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: """Compute a trigonometric cosine of each element of input. @@ -620,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array: return lax.cos(*promote_args_inexact('cos', x)) +@export @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. @@ -650,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array: return lax.tan(*promote_args_inexact('tan', x)) +@export @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: r"""Compute element-wise inverse of trigonometric sine of input. @@ -691,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array: return lax.asin(*promote_args_inexact('arcsin', x)) +@export @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric cosine of input. @@ -733,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array: return lax.acos(*promote_args_inexact('arccos', x)) +@export @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric tangent of input. @@ -773,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array: return lax.atan(*promote_args_inexact('arctan', x)) +@export @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic sine of input. @@ -827,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array: return lax.sinh(*promote_args_inexact('sinh', x)) +@export @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic cosine of input. @@ -880,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array: return lax.cosh(*promote_args_inexact('cosh', x)) +@export @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic sine of input. @@ -929,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array: return lax.asinh(*promote_args_inexact('arcsinh', x)) +@export @jit def arccosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic cosine of input. @@ -984,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array: return result +@export @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic tangent of input. @@ -1037,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array: return lax.tanh(*promote_args_inexact('tanh', x)) +@export @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic tangent of input. @@ -1085,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) +@export @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. @@ -1117,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array: return lax.sqrt(*promote_args_inexact('sqrt', x)) +@export @partial(jit, inline=True) def cbrt(x: ArrayLike, /) -> Array: """Calculates element-wise cube root of the input array. @@ -1144,6 +1174,7 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) + def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.add.at.""" if a.dtype == bool: @@ -1152,6 +1183,7 @@ def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: return a.at[indices].add(b).astype(bool) return a.at[indices].add(b) + @binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. @@ -1182,6 +1214,7 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.multiply.at.""" if a.dtype == bool: @@ -1191,6 +1224,7 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: else: return a.at[indices].mul(b) + @binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. @@ -1221,6 +1255,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) + @binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. @@ -1250,6 +1285,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. @@ -1279,6 +1315,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. @@ -1309,6 +1346,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) +@export @partial(jit, inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. @@ -1364,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.shift_left(*promote_args_numeric("left_shift", x, y)) +@export @partial(jit, inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) +@export @partial(jit, inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. @@ -1419,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.eq(*promote_args("equal", x, y)) +@export @partial(jit, inline=True) def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x != y``. @@ -1472,6 +1513,7 @@ def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.subtract.at.""" return a.at[indices].subtract(b) + @binary_ufunc(identity=None, at=_subtract_at) def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. @@ -1502,6 +1544,7 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.sub(*promote_args("subtract", x, y)) +@export @partial(jit, inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the arctangent of x1/x2, choosing the correct quadrant. @@ -1557,6 +1600,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) +@export @partial(jit, inline=True) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1617,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) +@export @partial(jit, inline=True) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1676,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.max(*promote_args("maximum", x, y)) +@export @partial(jit, inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: """Calculate element-wise base ``x`` exponential of ``y``. @@ -1722,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.pow(*promote_args_inexact("float_power", x, y)) +@export @partial(jit, inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise next floating point value after ``x`` towards ``y``. @@ -1749,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) +@export @partial(jit, inline=True) def spacing(x: ArrayLike, /) -> Array: """Return the spacing between ``x`` and the next adjacent number. @@ -1856,6 +1904,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) +@export @partial(jit, inline=True) def logical_not(x: ArrayLike, /) -> Array: """Compute NOT bool(x) element-wise. @@ -1901,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], lax_op(x.real, y.real)) return lax_op(x, y) + +@export @partial(jit, inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x >= y``. @@ -1946,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) +@export @partial(jit, inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x > y``. @@ -1992,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.gt, *promote_args("greater", x, y)) +@export @partial(jit, inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x <= y``. @@ -2038,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) +@export @partial(jit, inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x < y``. @@ -2083,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array: """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) + # Array API aliases +@export @partial(jit, inline=True) def acos(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccos`""" return arccos(*promote_args('acos', x)) + +@export @partial(jit, inline=True) def acosh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccosh`""" return arccosh(*promote_args('acosh', x)) + +@export @partial(jit, inline=True) def asin(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsin`""" return arcsin(*promote_args('asin', x)) + +@export @partial(jit, inline=True) def asinh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsinh`""" return arcsinh(*promote_args('asinh', x)) + +@export @partial(jit, inline=True) def atan(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan`""" return arctan(*promote_args('atan', x)) + +@export @partial(jit, inline=True) def atanh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctanh`""" return arctanh(*promote_args('atanh', x)) + +@export @partial(jit, inline=True) def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" return arctan2(*promote_args('atan2', x1, x2)) + +@export @jit def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value @@ -2154,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array: # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') + +@export @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. @@ -2205,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_fn(x1, x2) +@export @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.right_shift`.""" return right_shift(x1, x2) +@export @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. @@ -2246,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array: return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +@export @partial(jit, inline=True) def abs(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.absolute`.""" return absolute(x) +@export @jit def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer @@ -2291,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) +@export @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. @@ -2330,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) +@export @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the division of x1 by x2 element-wise @@ -2368,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.div(x1, x2) +@export def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.true_divide`.""" return true_divide(x1, x2) +@export @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise @@ -2427,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _float_divmod(x1, x2)[0] +@export @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise @@ -2481,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@export def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise base ``x1`` exponential of ``x2``. @@ -2565,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +@export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.power`""" return power(x1, x2) @@ -2604,6 +2687,7 @@ def _pow_int_int(x1, x2): return acc +@export @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2630,6 +2714,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) +@export @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2662,6 +2747,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return logaddexp(x1 * ln2, x2 * ln2) / ln2 +@export @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. @@ -2684,6 +2770,7 @@ def log2(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) +@export @partial(jit, inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise @@ -2707,6 +2794,7 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) +@export @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: """Calculate element-wise base-2 exponential of input. @@ -2741,6 +2829,7 @@ def exp2(x: ArrayLike, /) -> Array: return lax.exp2(x) +@export @jit def signbit(x: ArrayLike, /) -> Array: """Return the sign bit of array elements. @@ -2813,6 +2902,7 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 +@export @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute x1 * 2 ** x2 @@ -2862,6 +2952,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(isinf(x1) | (x1 == 0), x1, x) +@export @jit def frexp(x: ArrayLike, /) -> tuple[Array, Array]: """Split floating point values into mantissa and twos exponent. @@ -2915,6 +3006,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Returns element-wise remainder of the division. @@ -2962,11 +3054,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) +@export def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.remainder`""" return remainder(x1, x2) +@export @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise floating-point modulo operation. @@ -3008,6 +3102,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) +@export @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: """Calculate element-wise square of the input array. @@ -3057,6 +3152,7 @@ def square(x: ArrayLike, /) -> Array: return lax.square(x) +@export @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: r"""Convert angles from degrees to radians. @@ -3091,6 +3187,7 @@ def deg2rad(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, np.pi / 180)) +@export @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: r"""Convert angles from radians to degrees. @@ -3126,15 +3223,19 @@ def rad2deg(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, 180 / np.pi)) +@export def degrees(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.rad2deg`""" return rad2deg(x) + +@export def radians(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.deg2rad`""" return deg2rad(x) +@export @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: """Return element-wise complex-conjugate of the input. @@ -3164,11 +3265,13 @@ def conjugate(x: ArrayLike, /) -> Array: return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) +@export def conj(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.conjugate`""" return conjugate(x) +@export @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. @@ -3200,6 +3303,7 @@ def imag(val: ArrayLike, /) -> Array: return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) +@export @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. @@ -3231,6 +3335,7 @@ def real(val: ArrayLike, /) -> Array: return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) +@export @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: """Return element-wise fractional and integral parts of the input array. @@ -3264,6 +3369,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole +@export @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. @@ -3304,6 +3410,7 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) +@export @jit def isinf(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is infinite. @@ -3359,6 +3466,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) +@export def isposinf(x, /, out=None): """ Return boolean array indicating whether each element of input is positive infinite. @@ -3392,6 +3500,7 @@ def isposinf(x, /, out=None): return _isposneginf(np.inf, x, out) +@export def isneginf(x, /, out=None): """ Return boolean array indicating whether each element of input is negative infinite. @@ -3425,6 +3534,7 @@ def isneginf(x, /, out=None): return _isposneginf(-np.inf, x, out) +@export @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. @@ -3459,6 +3569,7 @@ def isnan(x: ArrayLike, /) -> Array: return lax.ne(x, x) +@export @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the heaviside step function. @@ -3508,6 +3619,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) +@export @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: r""" @@ -3556,6 +3668,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(idx_inf, _lax_const(x, np.inf), x) +@export @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: """Calculate element-wise reciprocal of the input. @@ -3589,6 +3702,7 @@ def reciprocal(x: ArrayLike, /) -> Array: return lax.integer_pow(x, -1) +@export @jit def sinc(x: ArrayLike, /) -> Array: r"""Calculate the normalized sinc function. diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e7a0e2142..f1e6d399b 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,9 +23,11 @@ from jax._src import api from jax._src import config from jax import lax from jax._src.numpy import lax_numpy as jnp -from jax._src.util import safe_map as map, safe_zip as zip +from jax._src.util import set_module, safe_map as map, safe_zip as zip +export = set_module('jax.numpy') + # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html _DIMENSION_NAME = r'\w+' _CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME) @@ -185,6 +187,7 @@ def _apply_excluded(func: Callable[..., Any], return new_func, dynamic_args, dynamic_kwargs +@export def vectorize(pyfunc, *, excluded=frozenset(), signature=None): """Define a vectorized function with broadcasting. diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 71d48c2b1..9bc8d0f6d 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -32,6 +32,14 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. _mod("jax.errors", exclude=["JaxRuntimeError"]), + _mod( + "jax.numpy", + exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating", + "dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo", + "flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim", + "number", "object_", "printoptions", "save", "savez", "set_printoptions", + "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] + ), _mod("jax.nn.initializers"), _mod( "jax.tree_util", @@ -46,7 +54,8 @@ class PackageStructureTest(jtu.JaxTestCase): if name not in include and (name.startswith('_') or name in exclude): continue obj = getattr(module, name) - if isinstance(obj, types.ModuleType): + if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)): + # No __module__ attribute expected. continue self.assertEqual(obj.__module__, module_name, f"{obj} has {obj.__module__=}, expected {module_name}")