Set __module__ attribute for objects in jax.numpy

This commit is contained in:
Jake VanderPlas 2024-11-15 06:03:54 -08:00
parent 9a0e9e55d8
commit f652b6ad6a
10 changed files with 396 additions and 10 deletions

View File

@ -343,6 +343,7 @@ _types_for_issubdtype = (type, np.dtype, ExtendedDType)
# TODO(jakevdp): consider whether to disallow None here. We allow it
# because np.issubdtype allows it (and treats it as equivalent to float64).
@set_module('jax.numpy')
def issubdtype(a: DTypeLike | ExtendedDType | None,
b: DTypeLike | ExtendedDType | None) -> bool:
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
@ -458,6 +459,7 @@ _dtype_kinds: dict[str, set] = {
}
@set_module('jax.numpy')
def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
@ -650,6 +652,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
"JAX's internal logic; please report it to the JAX maintainers."
)
@set_module('jax.numpy')
def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
"""Returns the type to which a binary operation should cast its arguments.

View File

@ -24,10 +24,14 @@ from jax._src.numpy.lax_numpy import (
arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
)
from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module
import numpy as np
export = set_module('jax.numpy')
__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"]
@ -87,7 +91,7 @@ class _Mgrid:
return stack(output_arr, 0)
mgrid = _Mgrid()
mgrid = export(_Mgrid())
class _Ogrid:
@ -129,7 +133,7 @@ class _Ogrid:
return meshgrid(*output, indexing='ij', sparse=True)
ogrid = _Ogrid()
ogrid = export(_Ogrid())
_IndexType = Union[ArrayLike, str, slice]
@ -279,7 +283,7 @@ class RClass(_AxisConcat):
op_name = "r_"
r_ = RClass()
r_ = export(RClass())
class CClass(_AxisConcat):
@ -327,7 +331,7 @@ class CClass(_AxisConcat):
op_name = "c_"
c_ = CClass()
c_ = export(CClass())
s_ = np.s_

File diff suppressed because it is too large Load Diff

View File

@ -33,6 +33,10 @@ from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module
export = set_module('jax.numpy')
@jit
@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
@export
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
r"""Returns the roots of a polynomial given the coefficients ``p``.
@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
return _roots_with_zeros(p_arr, num_leading_zeros)
@export
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
full: bool = False, w: ArrayLike | None = None, cov: bool = False
@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
return c
@export
@jit
def poly(seq_of_zeros: ArrayLike) -> Array:
r"""Returns the coefficients of a polynomial for the given sequence of roots.
@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
return a
@export
@partial(jit, static_argnames=['unroll'])
def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
r"""Evaluates the polynomial at specific values.
@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
return y
@export
@jit
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the sum of the two polynomials.
@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr)
@export
@partial(jit, static_argnames=('m',))
def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array:
r"""Returns the coefficients of the integration of specified order of a polynomial.
@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array
return true_divide(concatenate((p_arr, k_arr)), coeff)
@export
@partial(jit, static_argnames=('m',))
def polyder(p: ArrayLike, m: int = 1) -> Array:
r"""Returns the coefficients of the derivative of specified order of a polynomial.
@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
return p_arr[:-m] * coeff[::-1]
@export
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
r"""Returns the product of two polynomials.
@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
return convolve(a1_arr, a2_arr, mode='full')
@export
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
r"""Returns the quotient and remainder of polynomial division.
@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
return q, u_arr
@export
@jit
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the difference of two polynomials.

View File

@ -37,9 +37,11 @@ from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
NumpyComplexWarning)
set_module, NumpyComplexWarning)
export = set_module('jax.numpy')
_all = builtins.all
_lax_const = lax_internal._const
@ -222,6 +224,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
promote_integers=promote_integers)
@export
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
@ -296,6 +299,7 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
@ -307,6 +311,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
initial=initial, where_=where, promote_integers=promote_integers)
@export
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
@ -391,6 +396,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@export
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -473,6 +479,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@export
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -552,6 +559,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims, where_=where)
@export
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether all array elements along a given axis evaluate to True.
@ -608,6 +616,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims, where_=where)
@export
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether any of the array elements along a given axis evaluate to True.
@ -714,6 +723,7 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
initial=initial, where_=where)
@export
def amin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -721,6 +731,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None,
return min(a, axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
def amax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -740,6 +751,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
return size
@export
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
@ -843,6 +855,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
@export
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
"""Compute the weighed average.
@ -953,6 +966,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
return avg
@export
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
@ -1093,6 +1107,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
return _upcast_f16(computation_dtype), np.dtype(dtype)
@export
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
@ -1185,6 +1200,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))
@export
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
r"""Return the peak-to-peak range along a given axis.
@ -1236,6 +1252,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
return lax.sub(x, y)
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
@ -1295,6 +1312,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
return out
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -1377,6 +1395,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -1459,6 +1478,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -1542,6 +1562,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -1625,6 +1646,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, where: ArrayLike | None = None) -> Array:
@ -1716,6 +1738,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
return td
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@ -1818,6 +1841,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
return lax.convert_element_type(result, dtype)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@ -1939,6 +1963,7 @@ def _cumulative_reduction(
return result
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -1975,6 +2000,7 @@ def cumsum(a: ArrayLike, axis: int | None = None,
return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -2010,6 +2036,7 @@ def cumprod(a: ArrayLike, axis: int | None = None,
return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -2059,6 +2086,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None,
fill_nan=True, fill_value=0)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -2115,6 +2143,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None,
a, axis, dtype, out, promote_integers=True)
@export
def cumulative_sum(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
@ -2176,6 +2205,7 @@ def cumulative_sum(
return out
@export
def cumulative_prod(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
@ -2239,6 +2269,7 @@ def cumulative_prod(
# Quantiles
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
@ -2295,6 +2326,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
@ -2475,7 +2507,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
@ -2531,7 +2565,9 @@ def percentile(a: ArrayLike, q: ArrayLike,
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
@ -2591,6 +2627,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
method=method, keepdims=keepdims)
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
@ -2642,6 +2679,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
keepdims=keepdims, method='midpoint')
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,

View File

@ -35,10 +35,12 @@ from jax._src.numpy.lax_numpy import (
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike, promote_dtypes
from jax._src.util import canonicalize_axis
from jax._src.util import canonicalize_axis, set_module
from jax._src.typing import Array, ArrayLike
export = set_module('jax.numpy')
_lax_const = lax_internal._const
@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]:
return arr, num_unique1 + num_unique2
@export
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set difference of two 1D arrays.
@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
@export
def union1d(ar1: ArrayLike, ar2: ArrayLike,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set union of two 1D arrays.
@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *,
return where(arange(len(vals)) < num_results, vals, fill_value)
@export
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *,
size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set-wise xor of elements in two arrays.
@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
return vals
@export
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]:
@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return int1d
@export
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: bool = False, invert: bool = False, *,
method='auto') -> Array:
@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
return ret[0] if len(ret) == 1 else ret
@export
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, axis: int | None = None,
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple):
inverse_indices: Array
@export
def unique_all(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueAllResult:
"""Return unique values from x, along with indices, inverse indices, and counts.
@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
@export
def unique_counts(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueCountsResult:
"""Return unique values from x, along with counts.
@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
return _UniqueCountsResult(values=values, counts=counts)
@export
def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueInverseResult:
"""Return unique values from x, along with indices, inverse indices, and counts.
@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
@export
def unique_values(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> Array:
"""Return unique values from x, along with indices, inverse indices, and counts.

View File

@ -33,6 +33,8 @@ from jax._src.util import canonicalize_axis, set_module
import numpy as np
export = set_module("jax.numpy")
_AT_INPLACE_WARNING = """\
Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like
np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
@ -40,7 +42,7 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
"""
@set_module('jax.numpy')
@export
class ufunc:
"""Universal functions which operation element-by-element on arrays.
@ -586,6 +588,7 @@ class ufunc:
return result.reshape(*np.shape(A), *np.shape(B))
@export
def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
*, identity: Any = None) -> ufunc:
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function.

View File

@ -38,6 +38,10 @@ from jax._src.numpy.util import (
promote_shapes, _where, check_no_float0s)
from jax._src.numpy.ufunc_api import ufunc
from jax._src.numpy import reductions
from jax._src.util import set_module
export = set_module('jax.numpy')
_lax_const = lax._const
@ -75,6 +79,7 @@ def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None,
return decorator
@export
@partial(jit, inline=True)
def fabs(x: ArrayLike, /) -> Array:
"""Compute the element-wise absolute values of the real-valued input.
@ -119,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array:
return lax.abs(*promote_args_inexact('fabs', x))
@export
@partial(jit, inline=True)
def bitwise_invert(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.invert`."""
return lax.bitwise_not(*promote_args('bitwise_invert', x))
@export
@partial(jit, inline=True)
def bitwise_not(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.invert`."""
return lax.bitwise_not(*promote_args('bitwise_not', x))
@export
@partial(jit, inline=True)
def invert(x: ArrayLike, /) -> Array:
"""Compute the bitwise inversion of an input.
@ -223,6 +231,7 @@ def negative(x: ArrayLike, /) -> Array:
return lax.neg(*promote_args('negative', x))
@export
@partial(jit, inline=True)
def positive(x: ArrayLike, /) -> Array:
"""Return element-wise positive values of the input.
@ -271,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array:
return lax.asarray(*promote_args('positive', x))
@export
@partial(jit, inline=True)
def sign(x: ArrayLike, /) -> Array:
r"""Return an element-wise indication of sign of the input.
@ -321,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array:
return lax.sign(*promote_args('sign', x))
@export
@partial(jit, inline=True)
def floor(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer downwards.
@ -359,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array:
return lax.floor(*promote_args_inexact('floor', x))
@export
@partial(jit, inline=True)
def ceil(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer upwards.
@ -397,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array:
return lax.ceil(*promote_args_inexact('ceil', x))
@export
@partial(jit, inline=True)
def exp(x: ArrayLike, /) -> Array:
"""Calculate element-wise exponential of the input.
@ -438,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array:
return lax.exp(*promote_args_inexact('exp', x))
@export
@partial(jit, inline=True)
def log(x: ArrayLike, /) -> Array:
"""Calculate element-wise natural logarithm of the input.
@ -475,6 +489,7 @@ def log(x: ArrayLike, /) -> Array:
return lax.log(*promote_args_inexact('log', x))
@export
@partial(jit, inline=True)
def expm1(x: ArrayLike, /) -> Array:
"""Calculate ``exp(x)-1`` of each element of the input.
@ -519,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array:
return lax.expm1(*promote_args_inexact('expm1', x))
@export
@partial(jit, inline=True)
def log1p(x: ArrayLike, /) -> Array:
"""Calculates element-wise logarithm of one plus input, ``log(x+1)``.
@ -559,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array:
return lax.log1p(*promote_args_inexact('log1p', x))
@export
@partial(jit, inline=True)
def sin(x: ArrayLike, /) -> Array:
"""Compute a trigonometric sine of each element of input.
@ -590,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array:
return lax.sin(*promote_args_inexact('sin', x))
@export
@partial(jit, inline=True)
def cos(x: ArrayLike, /) -> Array:
"""Compute a trigonometric cosine of each element of input.
@ -620,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array:
return lax.cos(*promote_args_inexact('cos', x))
@export
@partial(jit, inline=True)
def tan(x: ArrayLike, /) -> Array:
"""Compute a trigonometric tangent of each element of input.
@ -650,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array:
return lax.tan(*promote_args_inexact('tan', x))
@export
@partial(jit, inline=True)
def arcsin(x: ArrayLike, /) -> Array:
r"""Compute element-wise inverse of trigonometric sine of input.
@ -691,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array:
return lax.asin(*promote_args_inexact('arcsin', x))
@export
@partial(jit, inline=True)
def arccos(x: ArrayLike, /) -> Array:
"""Compute element-wise inverse of trigonometric cosine of input.
@ -733,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array:
return lax.acos(*promote_args_inexact('arccos', x))
@export
@partial(jit, inline=True)
def arctan(x: ArrayLike, /) -> Array:
"""Compute element-wise inverse of trigonometric tangent of input.
@ -773,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array:
return lax.atan(*promote_args_inexact('arctan', x))
@export
@partial(jit, inline=True)
def sinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic sine of input.
@ -827,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array:
return lax.sinh(*promote_args_inexact('sinh', x))
@export
@partial(jit, inline=True)
def cosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic cosine of input.
@ -880,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array:
return lax.cosh(*promote_args_inexact('cosh', x))
@export
@partial(jit, inline=True)
def arcsinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic sine of input.
@ -929,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array:
return lax.asinh(*promote_args_inexact('arcsinh', x))
@export
@jit
def arccosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic cosine of input.
@ -984,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array:
return result
@export
@partial(jit, inline=True)
def tanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic tangent of input.
@ -1037,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array:
return lax.tanh(*promote_args_inexact('tanh', x))
@export
@partial(jit, inline=True)
def arctanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic tangent of input.
@ -1085,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array:
return lax.atanh(*promote_args_inexact('arctanh', x))
@export
@partial(jit, inline=True)
def sqrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise non-negative square root of the input array.
@ -1117,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array:
return lax.sqrt(*promote_args_inexact('sqrt', x))
@export
@partial(jit, inline=True)
def cbrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise cube root of the input array.
@ -1144,6 +1174,7 @@ def cbrt(x: ArrayLike, /) -> Array:
"""
return lax.cbrt(*promote_args_inexact('cbrt', x))
def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.add.at."""
if a.dtype == bool:
@ -1152,6 +1183,7 @@ def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array:
return a.at[indices].add(b).astype(bool)
return a.at[indices].add(b)
@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Add two arrays element-wise.
@ -1182,6 +1214,7 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.multiply.at."""
if a.dtype == bool:
@ -1191,6 +1224,7 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array:
else:
return a.at[indices].mul(b)
@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Multiply two arrays element-wise.
@ -1221,6 +1255,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and)
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise AND operation elementwise.
@ -1250,6 +1285,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or)
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise OR operation elementwise.
@ -1279,6 +1315,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor)
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise XOR operation elementwise.
@ -1309,6 +1346,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
@export
@partial(jit, inline=True)
def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise.
@ -1364,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.shift_left(*promote_args_numeric("left_shift", x, y))
@export
@partial(jit, inline=True)
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.left_shift`."""
return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y))
@export
@partial(jit, inline=True)
def equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Returns element-wise truth value of ``x == y``.
@ -1419,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.eq(*promote_args("equal", x, y))
@export
@partial(jit, inline=True)
def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Returns element-wise truth value of ``x != y``.
@ -1472,6 +1513,7 @@ def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.subtract.at."""
return a.at[indices].subtract(b)
@binary_ufunc(identity=None, at=_subtract_at)
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Subtract two arrays element-wise.
@ -1502,6 +1544,7 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.sub(*promote_args("subtract", x, y))
@export
@partial(jit, inline=True)
def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Compute the arctangent of x1/x2, choosing the correct quadrant.
@ -1557,6 +1600,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.atan2(*promote_args_inexact("arctan2", x1, x2))
@export
@partial(jit, inline=True)
def minimum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise minimum of the input arrays.
@ -1617,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.min(*promote_args("minimum", x, y))
@export
@partial(jit, inline=True)
def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise maximum of the input arrays.
@ -1676,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.max(*promote_args("maximum", x, y))
@export
@partial(jit, inline=True)
def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x`` exponential of ``y``.
@ -1722,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.pow(*promote_args_inexact("float_power", x, y))
@export
@partial(jit, inline=True)
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise next floating point value after ``x`` towards ``y``.
@ -1749,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
@export
@partial(jit, inline=True)
def spacing(x: ArrayLike, /) -> Array:
"""Return the spacing between ``x`` and the next adjacent number.
@ -1856,6 +1904,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
@export
@partial(jit, inline=True)
def logical_not(x: ArrayLike, /) -> Array:
"""Compute NOT bool(x) element-wise.
@ -1901,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array],
lax_op(x.real, y.real))
return lax_op(x, y)
@export
@partial(jit, inline=True)
def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x >= y``.
@ -1946,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y))
@export
@partial(jit, inline=True)
def greater(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x > y``.
@ -1992,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.gt, *promote_args("greater", x, y))
@export
@partial(jit, inline=True)
def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x <= y``.
@ -2038,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.le, *promote_args("less_equal", x, y))
@export
@partial(jit, inline=True)
def less(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x < y``.
@ -2083,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return _complex_comparison(lax.lt, *promote_args("less", x, y))
# Array API aliases
@export
@partial(jit, inline=True)
def acos(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arccos`"""
return arccos(*promote_args('acos', x))
@export
@partial(jit, inline=True)
def acosh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arccosh`"""
return arccosh(*promote_args('acosh', x))
@export
@partial(jit, inline=True)
def asin(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arcsin`"""
return arcsin(*promote_args('asin', x))
@export
@partial(jit, inline=True)
def asinh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arcsinh`"""
return arcsinh(*promote_args('asinh', x))
@export
@partial(jit, inline=True)
def atan(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctan`"""
return arctan(*promote_args('atan', x))
@export
@partial(jit, inline=True)
def atanh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctanh`"""
return arctanh(*promote_args('atanh', x))
@export
@partial(jit, inline=True)
def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctan2`"""
return arctan2(*promote_args('atan2', x1, x2))
@export
@jit
def bitwise_count(x: ArrayLike, /) -> Array:
r"""Counts the number of 1 bits in the binary representation of the absolute value
@ -2154,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array:
# Following numpy we take the absolute value and return uint8.
return lax.population_count(abs(x)).astype('uint8')
@export
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
@ -2205,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax_fn(x1, x2)
@export
@partial(jit, inline=True)
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.right_shift`."""
return right_shift(x1, x2)
@export
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
r"""Calculate the absolute value element-wise.
@ -2246,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array:
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
@export
@partial(jit, inline=True)
def abs(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.absolute`."""
return absolute(x)
@export
@jit
def rint(x: ArrayLike, /) -> Array:
"""Rounds the elements of x to the nearest integer
@ -2291,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
@export
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Copies the sign of each element in ``x2`` to the corresponding element in ``x1``.
@ -2330,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
@export
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculates the division of x1 by x2 element-wise
@ -2368,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.div(x1, x2)
@export
def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.true_divide`."""
return true_divide(x1, x2)
@export
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculates the floor division of x1 by x2 element-wise
@ -2427,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _float_divmod(x1, x2)[0]
@export
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
"""Calculates the integer quotient and remainder of x1 by x2 element-wise
@ -2481,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod
@export
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x1`` exponential of ``x2``.
@ -2565,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# Handle cases #2 and #3 under a jit:
return _power(x1, x2)
@export
def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.power`"""
return power(x1, x2)
@ -2604,6 +2687,7 @@ def _pow_int_int(x1, x2):
return acc
@export
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute ``log(exp(x1) + exp(x2))`` avoiding overflow.
@ -2630,6 +2714,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax_other.logaddexp(x1, x2)
@export
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.
@ -2662,6 +2747,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return logaddexp(x1 * ln2, x2 * ln2) / ln2
@export
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
"""Calculates the base-2 logarithm of ``x`` element-wise.
@ -2684,6 +2770,7 @@ def log2(x: ArrayLike, /) -> Array:
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@export
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
"""Calculates the base-10 logarithm of x element-wise
@ -2707,6 +2794,7 @@ def log10(x: ArrayLike, /) -> Array:
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@export
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
"""Calculate element-wise base-2 exponential of input.
@ -2741,6 +2829,7 @@ def exp2(x: ArrayLike, /) -> Array:
return lax.exp2(x)
@export
@jit
def signbit(x: ArrayLike, /) -> Array:
"""Return the sign bit of array elements.
@ -2813,6 +2902,7 @@ def _normalize_float(x):
return lax.bitcast_convert_type(x1, int_type), x2
@export
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute x1 * 2 ** x2
@ -2862,6 +2952,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(isinf(x1) | (x1 == 0), x1, x)
@export
@jit
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
"""Split floating point values into mantissa and twos exponent.
@ -2915,6 +3006,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
@export
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Returns element-wise remainder of the division.
@ -2962,11 +3054,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
@export
def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.remainder`"""
return remainder(x1, x2)
@export
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculate element-wise floating-point modulo operation.
@ -3008,6 +3102,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.rem(*promote_args_numeric("fmod", x1, x2))
@export
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
"""Calculate element-wise square of the input array.
@ -3057,6 +3152,7 @@ def square(x: ArrayLike, /) -> Array:
return lax.square(x)
@export
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
r"""Convert angles from degrees to radians.
@ -3091,6 +3187,7 @@ def deg2rad(x: ArrayLike, /) -> Array:
return lax.mul(x, _lax_const(x, np.pi / 180))
@export
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
r"""Convert angles from radians to degrees.
@ -3126,15 +3223,19 @@ def rad2deg(x: ArrayLike, /) -> Array:
return lax.mul(x, _lax_const(x, 180 / np.pi))
@export
def degrees(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.rad2deg`"""
return rad2deg(x)
@export
def radians(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.deg2rad`"""
return deg2rad(x)
@export
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
"""Return element-wise complex-conjugate of the input.
@ -3164,11 +3265,13 @@ def conjugate(x: ArrayLike, /) -> Array:
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
@export
def conj(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.conjugate`"""
return conjugate(x)
@export
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
"""Return element-wise imaginary of part of the complex argument.
@ -3200,6 +3303,7 @@ def imag(val: ArrayLike, /) -> Array:
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@export
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
"""Return element-wise real part of the complex argument.
@ -3231,6 +3335,7 @@ def real(val: ArrayLike, /) -> Array:
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
@export
@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
"""Return element-wise fractional and integral parts of the input array.
@ -3264,6 +3369,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
return x - whole, whole
@export
@partial(jit, inline=True)
def isfinite(x: ArrayLike, /) -> Array:
"""Return a boolean array indicating whether each element of input is finite.
@ -3304,6 +3410,7 @@ def isfinite(x: ArrayLike, /) -> Array:
return lax.full_like(x, True, dtype=np.bool_)
@export
@jit
def isinf(x: ArrayLike, /) -> Array:
"""Return a boolean array indicating whether each element of input is infinite.
@ -3359,6 +3466,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
return lax.full_like(x, False, dtype=np.bool_)
@export
def isposinf(x, /, out=None):
"""
Return boolean array indicating whether each element of input is positive infinite.
@ -3392,6 +3500,7 @@ def isposinf(x, /, out=None):
return _isposneginf(np.inf, x, out)
@export
def isneginf(x, /, out=None):
"""
Return boolean array indicating whether each element of input is negative infinite.
@ -3425,6 +3534,7 @@ def isneginf(x, /, out=None):
return _isposneginf(-np.inf, x, out)
@export
@partial(jit, inline=True)
def isnan(x: ArrayLike, /) -> Array:
"""Returns a boolean array indicating whether each element of input is ``NaN``.
@ -3459,6 +3569,7 @@ def isnan(x: ArrayLike, /) -> Array:
return lax.ne(x, x)
@export
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Compute the heaviside step function.
@ -3508,6 +3619,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
@export
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""
@ -3556,6 +3668,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(idx_inf, _lax_const(x, np.inf), x)
@export
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
"""Calculate element-wise reciprocal of the input.
@ -3589,6 +3702,7 @@ def reciprocal(x: ArrayLike, /) -> Array:
return lax.integer_pow(x, -1)
@export
@jit
def sinc(x: ArrayLike, /) -> Array:
r"""Calculate the normalized sinc function.

View File

@ -23,9 +23,11 @@ from jax._src import api
from jax._src import config
from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import safe_map as map, safe_zip as zip
from jax._src.util import set_module, safe_map as map, safe_zip as zip
export = set_module('jax.numpy')
# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
_DIMENSION_NAME = r'\w+'
_CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME)
@ -185,6 +187,7 @@ def _apply_excluded(func: Callable[..., Any],
return new_func, dynamic_args, dynamic_kwargs
@export
def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
"""Define a vectorized function with broadcasting.

View File

@ -32,6 +32,14 @@ class PackageStructureTest(jtu.JaxTestCase):
@parameterized.parameters([
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors", exclude=["JaxRuntimeError"]),
_mod(
"jax.numpy",
exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating",
"dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo",
"flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim",
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
@ -46,7 +54,8 @@ class PackageStructureTest(jtu.JaxTestCase):
if name not in include and (name.startswith('_') or name in exclude):
continue
obj = getattr(module, name)
if isinstance(obj, types.ModuleType):
if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)):
# No __module__ attribute expected.
continue
self.assertEqual(obj.__module__, module_name,
f"{obj} has {obj.__module__=}, expected {module_name}")