Rename _wraps to implements

This commit is contained in:
Jake VanderPlas 2024-01-24 14:14:19 -08:00
parent 4646c64f54
commit 43a9faa06a
47 changed files with 569 additions and 571 deletions

View File

@ -22,7 +22,7 @@ from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, implements
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
@ -105,28 +105,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
return transformed
@_wraps(np.fft.fftn)
@implements(np.fft.fftn)
def fftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
@_wraps(np.fft.ifftn)
@implements(np.fft.ifftn)
def ifftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
@_wraps(np.fft.rfftn)
@implements(np.fft.rfftn)
def rfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
@_wraps(np.fft.irfftn)
@implements(np.fft.irfftn)
def irfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
@ -150,31 +150,31 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType,
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft)
@implements(np.fft.fft)
def fft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis,
norm=norm)
@_wraps(np.fft.ifft)
@implements(np.fft.ifft)
def ifft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis,
norm=norm)
@_wraps(np.fft.rfft)
@implements(np.fft.rfft)
def rfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis,
norm=norm)
@_wraps(np.fft.irfft)
@implements(np.fft.irfft)
def irfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis,
norm=norm)
@_wraps(np.fft.hfft)
@implements(np.fft.hfft)
def hfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
conj_a = ufuncs.conj(a)
@ -183,7 +183,7 @@ def hfft(a: ArrayLike, n: int | None = None,
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis,
norm=norm) * nn
@_wraps(np.fft.ihfft)
@implements(np.fft.ihfft)
def ihfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
_axis_check_1d('ihfft', axis)
@ -206,32 +206,32 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft2)
@implements(np.fft.fft2)
def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.ifft2)
@implements(np.fft.ifft2)
def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.rfft2)
@implements(np.fft.rfft2)
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.irfft2)
@implements(np.fft.irfft2)
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.fftfreq, extra_params="""
@implements(np.fft.fftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
@ -266,7 +266,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
return k / jnp.array(d * n, dtype=dtype)
@_wraps(np.fft.rfftfreq, extra_params="""
@implements(np.fft.rfftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
@ -292,7 +292,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
return k / jnp.array(d * n, dtype=dtype)
@_wraps(np.fft.fftshift)
@implements(np.fft.fftshift)
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("fftshift", x)
x = jnp.asarray(x)
@ -308,7 +308,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
return jnp.roll(x, shift, axes)
@_wraps(np.fft.ifftshift)
@implements(np.fft.ifftshift)
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("ifftshift", x)
x = jnp.asarray(x)

File diff suppressed because it is too large Load Diff

View File

@ -30,7 +30,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike
from jax._src.numpy.util import implements, promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array
@ -63,7 +63,7 @@ def _H(x: ArrayLike) -> Array:
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@_wraps(np.linalg.cholesky)
@implements(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.cholesky", a)
@ -86,7 +86,7 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | SVDResult: ...
@_wraps(np.linalg.svd)
@implements(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | SVDResult:
@ -115,7 +115,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)
@_wraps(np.linalg.matrix_power)
@implements(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
check_arraylike("jnp.linalg.matrix_power", a)
@ -154,7 +154,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
return result
@_wraps(np.linalg.matrix_rank)
@implements(np.linalg.matrix_rank)
@jit
def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
check_arraylike("jnp.linalg.matrix_rank", M)
@ -211,7 +211,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
return sign_diag * sign_taus, log_abs_det
@_wraps(
@implements(
np.linalg.slogdet,
extra_params=textwrap.dedent("""
method: string, optional
@ -357,7 +357,7 @@ def _det_3x3(a: Array) -> Array:
@custom_jvp
@_wraps(np.linalg.det)
@implements(np.linalg.det)
@jit
def det(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.det", a)
@ -383,7 +383,7 @@ def _det_jvp(primals, tangents):
return y, jnp.trace(z, axis1=-1, axis2=-2)
@_wraps(np.linalg.eig, lax_description="""
@implements(np.linalg.eig, lax_description="""
This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input,
and ``complex128`` for 64-bit input.
@ -399,7 +399,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
return w, v
@_wraps(np.linalg.eigvals)
@implements(np.linalg.eigvals)
@jit
def eigvals(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.eigvals", a)
@ -407,7 +407,7 @@ def eigvals(a: ArrayLike) -> Array:
compute_right_eigenvectors=False)[0]
@_wraps(np.linalg.eigh)
@implements(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: str | None = None,
symmetrize_input: bool = True) -> EighResult:
@ -425,7 +425,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
return EighResult(w, v)
@_wraps(np.linalg.eigvalsh)
@implements(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
check_arraylike("jnp.linalg.eigvalsh", a)
@ -434,7 +434,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
@partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
@implements(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
@ -494,7 +494,7 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
return p, p_dot
@_wraps(np.linalg.inv)
@implements(np.linalg.inv)
@jit
def inv(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.inv", a)
@ -506,7 +506,7 @@ def inv(a: ArrayLike) -> Array:
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))
@_wraps(np.linalg.norm)
@implements(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: int | str | None = None,
axis: None | tuple[int, ...] | int = None,
@ -608,7 +608,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...
@_wraps(np.linalg.qr)
@implements(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
check_arraylike("jnp.linalg.qr", a)
@ -628,7 +628,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
return QRResult(q, r)
@_wraps(np.linalg.solve)
@implements(np.linalg.solve)
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
@ -689,7 +689,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
@implements(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
@ -710,7 +710,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
return _jit_lstsq(a, b, rcond)
@_wraps(getattr(np.linalg, "cross", None))
@implements(getattr(np.linalg, "cross", None))
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
@ -722,7 +722,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
return jnp.cross(x1, x2, axis=axis)
@_wraps(getattr(np.linalg, "outer", None))
@implements(getattr(np.linalg, "outer", None))
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
@ -731,7 +731,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return x1[:, None] * x2[None, :]
@_wraps(getattr(np.linalg, "matrix_norm", None))
@implements(getattr(np.linalg, "matrix_norm", None))
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
"""
Computes the matrix norm of a matrix (or a stack of matrices) x.
@ -740,7 +740,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') ->
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))
@_wraps(getattr(np.linalg, "matrix_transpose", None))
@implements(getattr(np.linalg, "matrix_transpose", None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes a matrix (or a stack of matrices) x."""
check_arraylike('jnp.linalg.matrix_transpose', x)
@ -751,7 +751,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
@_wraps(getattr(np.linalg, "vector_norm", None))
@implements(getattr(np.linalg, "vector_norm", None))
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Computes the vector norm of a vector (or batch of vectors) x."""
@ -764,31 +764,31 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
return norm(x, axis=axis, keepdims=keepdims, ord=ord)
@_wraps(getattr(np.linalg, "vecdot", None))
@implements(getattr(np.linalg, "vecdot", None))
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
return jnp.vecdot(x1, x2, axis=axis)
@_wraps(getattr(np.linalg, "matmul", None))
@implements(getattr(np.linalg, "matmul", None))
def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike('jnp.linalg.matmul', x1, x2)
return jnp.matmul(x1, x2)
@_wraps(getattr(np.linalg, "tensordot", None))
@implements(getattr(np.linalg, "tensordot", None))
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
check_arraylike('jnp.linalg.tensordot', x1, x2)
return jnp.tensordot(x1, x2, axes=axes)
@_wraps(getattr(np.linalg, "svdvals", None))
@implements(getattr(np.linalg, "svdvals", None))
def svdvals(x: ArrayLike, /) -> Array:
check_arraylike('jnp.linalg.svdvals', x)
return svd(x, compute_uv=False, hermitian=False)
@_wraps(getattr(np.linalg, "diagonal", None))
@implements(getattr(np.linalg, "diagonal", None))
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
check_arraylike('jnp.linalg.diagonal', x)
return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)

View File

@ -31,7 +31,7 @@ from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, _wraps)
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, implements)
from jax._src.typing import Array, ArrayLike
@ -57,7 +57,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
@_wraps(np.roots, lax_description="""\
@implements(np.roots, lax_description="""\
Unlike the numpy version of this function, the JAX version returns the roots in
a complex array regardless of the values of the roots. Additionally, the jax
version of this function adds the ``strip_zeros`` function which must be set to
@ -106,7 +106,7 @@ _POLYFIT_DOC = """\
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
Also, it works best on rcond <= 10e-3 values.
"""
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
@implements(np.polyfit, lax_description=_POLYFIT_DOC)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
full: bool = False, w: Array | None = None, cov: bool = False
@ -187,7 +187,7 @@ np.poly returns an array with a real dtype in such cases.
jax returns an array with a complex dtype in such cases.
"""
@_wraps(np.poly, lax_description=_POLY_DOC)
@implements(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros: Array) -> Array:
check_arraylike('poly', seq_of_zeros)
@ -214,7 +214,7 @@ def poly(seq_of_zeros: Array) -> Array:
return a
@_wraps(np.polyval, lax_description="""\
@implements(np.polyval, lax_description="""\
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
have a major impact on performance for evaluating high-order polynomials. The
parameter controls the number of unrolled steps with ``lax.scan`` inside the
@ -231,7 +231,7 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
return y
@_wraps(np.polyadd)
@implements(np.polyadd)
@jit
def polyadd(a1: Array, a2: Array) -> Array:
check_arraylike("polyadd", a1, a2)
@ -242,7 +242,7 @@ def polyadd(a1: Array, a2: Array) -> Array:
return a2.at[-a1.shape[0]:].add(a1)
@_wraps(np.polyint)
@implements(np.polyint)
@partial(jit, static_argnames=('m',))
def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
@ -265,7 +265,7 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
return true_divide(concatenate((p, k_arr)), coeff)
@_wraps(np.polyder)
@implements(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p: Array, m: int = 1) -> Array:
check_arraylike("polyder", p)
@ -288,7 +288,7 @@ considered zero may lead to inconsistent results between NumPy and JAX, and even
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
"""
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
@implements(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
@ -300,7 +300,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
a2_arr = asarray([0], dtype=a1_arr.dtype)
return convolve(a1_arr, a2_arr, mode='full')
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
@implements(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
@ -317,7 +317,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
return q, u_arr
@_wraps(np.polysub)
@implements(np.polysub)
@jit
def polysub(a1: Array, a2: Array) -> Array:
check_arraylike("polysub", a1, a2)

View File

@ -31,7 +31,7 @@ from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
@ -219,7 +219,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
@ -238,7 +238,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
@ -256,7 +256,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.max, skip_params=['out'])
@implements(np.max, skip_params=['out'])
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -271,7 +271,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.min, skip_params=['out'])
@implements(np.min, skip_params=['out'])
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@ -284,7 +284,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.all, skip_params=['out'])
@implements(np.all, skip_params=['out'])
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
@ -296,7 +296,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.any, skip_params=['out'])
@implements(np.any, skip_params=['out'])
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
@ -316,7 +316,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
@_wraps(np.mean, skip_params=['out'])
@implements(np.mean, skip_params=['out'])
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
@ -365,7 +365,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
@_wraps(np.average)
@implements(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@ -425,7 +425,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
return avg
@_wraps(np.var, skip_params=['out'])
@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
@ -486,7 +486,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
return _upcast_f16(computation_dtype), np.dtype(dtype)
@_wraps(np.std, skip_params=['out'])
@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
@ -506,7 +506,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@_wraps(np.ptp, skip_params=['out'])
@implements(np.ptp, skip_params=['out'])
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@ -522,7 +522,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
return lax.sub(x, y)
@_wraps(np.count_nonzero)
@implements(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
@ -546,7 +546,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
else:
return out
@_wraps(np.nanmin, skip_params=['out'])
@implements(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -555,7 +555,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nanmax, skip_params=['out'])
@implements(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -564,7 +564,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nansum, skip_params=['out'])
@implements(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -578,7 +578,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
if nansum.__doc__ is not None:
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@implements(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@ -588,7 +588,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nanmean, skip_params=['out'])
@implements(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, where: ArrayLike | None = None) -> Array:
@ -608,7 +608,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
return td
@_wraps(np.nanvar, skip_params=['out'])
@implements(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@ -639,7 +639,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
return lax.convert_element_type(result, dtype)
@_wraps(np.nanstd, skip_params=['out'])
@implements(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@ -664,7 +664,7 @@ match the dtype of the input.
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
@_wraps(np_reduction, skip_params=['out'],
@implements(np_reduction, skip_params=['out'],
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -709,7 +709,7 @@ nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
# Quantiles
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
@ -725,7 +725,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@implements(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
@ -862,7 +862,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@implements(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
@ -874,7 +874,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@implements(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
@ -887,7 +887,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
interpolation=interpolation, method=method,
keepdims=keepdims)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@implements(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
@ -896,7 +896,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@implements(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,

View File

@ -34,7 +34,7 @@ from jax._src.numpy.lax_numpy import (
sort, where, zeros)
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, implements
from jax._src.util import canonicalize_axis
from jax._src.typing import Array, ArrayLike
@ -61,7 +61,7 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
else:
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
@_wraps(np.setdiff1d,
@implements(np.setdiff1d,
lax_description=_dedent("""
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
@ -98,7 +98,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
@_wraps(np.union1d,
@implements(np.union1d,
lax_description=_dedent("""
Because the size of the output of ``union1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
@ -125,7 +125,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
return cast(Array, out)
@_wraps(np.setxor1d, lax_description="""
@implements(np.setxor1d, lax_description="""
In the JAX version, the input arrays are explicitly flattened regardless
of assume_unique value.
""")
@ -169,7 +169,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo
return aux, mask
@_wraps(np.intersect1d)
@implements(np.intersect1d)
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False) -> Array | tuple[Array, Array, Array]:
check_arraylike("intersect1d", ar1, ar2)
@ -206,7 +206,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return int1d
@_wraps(np.isin, lax_description="""
@implements(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def isin(element: ArrayLike, test_elements: ArrayLike,
@ -312,7 +312,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
ret += (mask.sum(),)
return ret[0] if len(ret) == 1 else ret
@_wraps(np.unique, skip_params=['axis'],
@implements(np.unique, skip_params=['axis'],
lax_description=_dedent("""
Because the size of the output of ``unique`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
@ -368,7 +368,7 @@ class _UniqueInverseResult(NamedTuple):
inverse_indices: Array
@_wraps(getattr(np, "unique_all", None))
@implements(getattr(np, "unique_all", None))
def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
check_arraylike("unique_all", x)
values, indices, inverse_indices, counts = unique(
@ -376,21 +376,21 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
@_wraps(getattr(np, "unique_counts", None))
@implements(getattr(np, "unique_counts", None))
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
check_arraylike("unique_counts", x)
values, counts = unique(x, return_counts=True, equal_nan=False)
return _UniqueCountsResult(values=values, counts=counts)
@_wraps(getattr(np, "unique_inverse", None))
@implements(getattr(np, "unique_inverse", None))
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
check_arraylike("unique_inverse", x)
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
@_wraps(getattr(np, "unique_values", None))
@implements(getattr(np, "unique_values", None))
def unique_values(x: ArrayLike, /) -> Array:
check_arraylike("unique_values", x)
return cast(Array, unique(x, equal_nan=False))

View File

@ -27,7 +27,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis, set_module
import numpy as np
@ -131,7 +131,7 @@ class ufunc:
raise NotImplementedError(f"where argument of {self}")
return self._call(*args, **kwargs)
@_wraps(np.ufunc.reduce, module="numpy.ufunc")
@implements(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
@ -219,7 +219,7 @@ class ufunc:
result = result.reshape(final_shape)
return result
@_wraps(np.ufunc.accumulate, module="numpy.ufunc")
@implements(np.ufunc.accumulate, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None) -> Array:
@ -257,7 +257,7 @@ class ufunc:
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
return _moveaxis(result, 0, axis)
@_wraps(np.ufunc.accumulate, module="numpy.ufunc")
@implements(np.ufunc.accumulate, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: bool = True) -> Array:
@ -296,7 +296,7 @@ class ufunc:
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
return carry[1]
@_wraps(np.ufunc.reduceat, module="numpy.ufunc")
@implements(np.ufunc.reduceat, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@ -335,7 +335,7 @@ class ufunc:
out)
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
@_wraps(np.ufunc.outer, module="numpy.ufunc")
@implements(np.ufunc.outer, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0])
def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array:
if self.nin != 2:

View File

@ -34,7 +34,7 @@ from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps, check_no_float0s)
promote_shapes, _where, implements, check_no_float0s)
_lax_const = lax._const
@ -68,9 +68,9 @@ def _one_to_one_unop(
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn, module='numpy')(fn)
return implements(numpy_fn, module='numpy')(fn)
def _one_to_one_binop(
@ -87,9 +87,9 @@ def _one_to_one_binop(
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn, module='numpy')(fn)
return implements(numpy_fn, module='numpy')(fn)
def _maybe_bool_binop(
@ -102,9 +102,9 @@ def _maybe_bool_binop(
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn, module='numpy')(fn)
return implements(numpy_fn, module='numpy')(fn)
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
@ -120,7 +120,7 @@ def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
return lax_fn(x1, x2)
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
return _wraps(numpy_fn, module='numpy')(fn)
return implements(numpy_fn, module='numpy')(fn)
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ...
@ -130,7 +130,7 @@ def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ...
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ...
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp:
@_wraps(np_op, update_doc=False, module='numpy')
@implements(np_op, update_doc=False, module='numpy')
@partial(jit, inline=True)
def op(*args):
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
@ -214,14 +214,14 @@ atanh = _one_to_one_unop(getattr(np, "atanh", np.arctanh), lax.atanh, True)
atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True)
@_wraps(getattr(np, 'bitwise_count', None), module='numpy')
@implements(getattr(np, 'bitwise_count', None), module='numpy')
@jit
def bitwise_count(x: ArrayLike, /) -> Array:
x, = promote_args_numeric("bitwise_count", x)
# Following numpy we take the absolute value and return uint8.
return lax.population_count(abs(x)).astype('uint8')
@_wraps(np.right_shift, module='numpy')
@implements(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
@ -229,7 +229,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)
@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
@implements(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
@partial(jit, inline=True)
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2)
@ -237,16 +237,16 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)
@_wraps(np.absolute, module='numpy')
@implements(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs, module='numpy')(absolute)
abs = implements(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@implements(np.rint, module='numpy')
@jit
def rint(x: ArrayLike, /) -> Array:
check_arraylike('rint', x)
@ -258,7 +258,7 @@ def rint(x: ArrayLike, /) -> Array:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
@_wraps(np.copysign, module='numpy')
@implements(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_inexact("copysign", x1, x2)
@ -267,7 +267,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
@_wraps(np.true_divide, module='numpy')
@implements(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_inexact("true_divide", x1, x2)
@ -276,7 +276,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@implements(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric("floor_divide", x1, x2)
@ -301,7 +301,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _float_divmod(x1, x2)[0]
@_wraps(np.divmod, module='numpy')
@implements(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
x1, x2 = promote_args_numeric("divmod", x1, x2)
@ -323,7 +323,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod
@_wraps(np.power, module='numpy')
@implements(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
check_no_float0s("power", x1, x2)
@ -393,7 +393,7 @@ def _pow_int_int(x1, x2):
@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@implements(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_inexact("logaddexp", x1, x2)
@ -431,7 +431,7 @@ def _logaddexp_jvp(primals, tangents):
@custom_jvp
@_wraps(np.logaddexp2, module='numpy')
@implements(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
@ -459,28 +459,28 @@ def _logaddexp2_jvp(primals, tangents):
return primal_out, tangent_out
@_wraps(np.log2, module='numpy')
@implements(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
x, = promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@_wraps(np.log10, module='numpy')
@implements(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
x, = promote_args_inexact("log10", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@_wraps(np.exp2, module='numpy')
@implements(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
x, = promote_args_inexact("exp2", x)
return lax.exp2(x)
@_wraps(np.signbit, module='numpy')
@implements(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike, /) -> Array:
x, = promote_args("signbit", x)
@ -511,7 +511,7 @@ def _normalize_float(x):
return lax.bitcast_convert_type(x1, int_type), x2
@_wraps(np.ldexp, module='numpy')
@implements(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("ldexp", x1, x2)
@ -560,7 +560,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
@_wraps(np.frexp, module='numpy')
@implements(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
check_arraylike("frexp", x)
@ -584,7 +584,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
@_wraps(np.remainder, module='numpy')
@implements(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric("remainder", x1, x2)
@ -596,10 +596,10 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
do_plus = lax.bitwise_and(
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
mod = _wraps(np.mod, module='numpy')(remainder)
mod = implements(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@implements(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("fmod", x1, x2)
@ -608,7 +608,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.rem(*promote_args_numeric("fmod", x1, x2))
@_wraps(np.square, module='numpy')
@implements(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
check_arraylike("square", x)
@ -616,14 +616,14 @@ def square(x: ArrayLike, /) -> Array:
return lax.integer_pow(x, 2)
@_wraps(np.deg2rad, module='numpy')
@implements(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
x, = promote_args_inexact("deg2rad", x)
return lax.mul(x, _lax_const(x, np.pi / 180))
@_wraps(np.rad2deg, module='numpy')
@implements(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
x, = promote_args_inexact("rad2deg", x)
@ -634,7 +634,7 @@ degrees = rad2deg
radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@implements(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
check_arraylike("conjugate", x)
@ -642,20 +642,20 @@ def conjugate(x: ArrayLike, /) -> Array:
conj = conjugate
@_wraps(np.imag)
@implements(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
check_arraylike("imag", val)
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@_wraps(np.real)
@implements(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@implements(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
check_arraylike("modf", x)
@ -666,7 +666,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
return x - whole, whole
@_wraps(np.isfinite, module='numpy')
@implements(np.isfinite, module='numpy')
@partial(jit, inline=True)
def isfinite(x: ArrayLike, /) -> Array:
check_arraylike("isfinite", x)
@ -679,7 +679,7 @@ def isfinite(x: ArrayLike, /) -> Array:
return lax.full_like(x, True, dtype=np.bool_)
@_wraps(np.isinf, module='numpy')
@implements(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
check_arraylike("isinf", x)
@ -707,24 +707,24 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
return lax.full_like(x, False, dtype=np.bool_)
isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
isposinf: UnOp = implements(np.isposinf, skip_params=['out'])(
lambda x, /, out=None: _isposneginf(np.inf, x, out)
)
isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
isneginf: UnOp = implements(np.isneginf, skip_params=['out'])(
lambda x, /, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan, module='numpy')
@implements(np.isnan, module='numpy')
@partial(jit, inline=True)
def isnan(x: ArrayLike, /) -> Array:
check_arraylike("isnan", x)
return lax.ne(x, x)
@_wraps(np.heaviside, module='numpy')
@implements(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("heaviside", x1, x2)
@ -734,7 +734,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
@_wraps(np.hypot, module='numpy')
@implements(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("hypot", x1, x2)
@ -745,7 +745,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
@_wraps(np.reciprocal, module='numpy')
@implements(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
check_arraylike("reciprocal", x)
@ -753,7 +753,7 @@ def reciprocal(x: ArrayLike, /) -> Array:
return lax.integer_pow(x, -1)
@_wraps(np.sinc, update_doc=False)
@implements(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike, /) -> Array:
check_arraylike("sinc", x)

View File

@ -112,13 +112,13 @@ def _parse_parameters(body: str) -> dict[str, str]:
def _parse_extra_params(extra_params: str) -> dict[str, str]:
"""Parse the extra parameters passed to _wraps()"""
"""Parse the extra parameters passed to implements()"""
parameters = _parameter_break.split(extra_params.strip('\n'))
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
def _wraps(
fun: Callable[..., Any] | None,
def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
@ -126,46 +126,46 @@ def _wraps(
extra_params: str | None = None,
module: str | None = None,
) -> Callable[[_T], _T]:
"""Specialized version of functools.wraps for wrapping numpy functions.
"""Decorator for JAX functions which implement a specified NumPy function.
This produces a wrapped function with a modified docstring. In particular, if
`update_doc` is True, parameters listed in the wrapped function that are not
supported by the decorated function will be removed from the docstring. For
this reason, it is important that parameter names match those in the original
numpy function.
This mainly contains logic to copy and modify the docstring of the original
function. In particular, if `update_doc` is True, parameters listed in the
original function that are not supported by the decorated function will
be removed from the docstring. For this reason, it is important that parameter
names match those in the original numpy function.
Args:
fun: The function being wrapped
original_fun: The original function being implemented
update_doc: whether to transform the numpy docstring to remove references of
parameters that are supported by the numpy version but not the JAX version.
If False, include the numpy docstring verbatim.
lax_description: a string description that will be added to the beginning of
the docstring.
sections: a list of sections to include in the docstring. The default is
["Parameters", "returns", "References"]
["Parameters", "Returns", "References"]
skip_params: a list of strings containing names of parameters accepted by the
function that should be skipped in the parameter list.
extra_params: an optional string containing additional parameter descriptions.
When ``update_doc=True``, these will be added to the list of parameter
descriptions in the updated doc.
module: an optional string specifying the module from which the wrapped function
module: an optional string specifying the module from which the original function
is imported. This is useful for objects such as ufuncs, where the module cannot
be determined from the wrapped function itself.
be determined from the original function itself.
"""
def wrap(op):
op.__np_wrapped__ = fun
# Allows this pattern: @wraps(getattr(np, 'new_function', None))
if fun is None:
def decorator(wrapped_fun):
wrapped_fun.__np_wrapped__ = original_fun
# Allows this pattern: @implements(getattr(np, 'new_function', None))
if original_fun is None:
if lax_description:
op.__doc__ = lax_description
return op
docstr = getattr(fun, "__doc__", None)
name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))
wrapped_fun.__doc__ = lax_description
return wrapped_fun
docstr = getattr(original_fun, "__doc__", None)
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
try:
mod = module or fun.__module__
mod = module or original_fun.__module__
except AttributeError:
if config.enable_checks.value:
raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().")
else:
name = f"{mod}.{name}"
if docstr:
@ -173,7 +173,7 @@ def _wraps(
parsed = _parse_numpydoc(docstr)
if update_doc and 'Parameters' in parsed.sections:
code = getattr(getattr(op, "__wrapped__", op), "__code__", None)
code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None)
# Remove unrecognized parameter descriptions.
parameters = _parse_parameters(parsed.sections['Parameters'])
if extra_params:
@ -211,18 +211,18 @@ def _wraps(
except:
if config.enable_checks.value:
raise
docstr = fun.__doc__
docstr = original_fun.__doc__
op.__doc__ = docstr
wrapped_fun.__doc__ = docstr
for attr in ['__name__', '__qualname__']:
try:
value = getattr(fun, attr)
value = getattr(original_fun, attr)
except AttributeError:
pass
else:
setattr(op, attr, value)
return op
return wrap
setattr(wrapped_fun, attr, value)
return wrapped_fun
return decorator
_dtype = partial(dtypes.dtype, canonicalize=True)

View File

@ -19,7 +19,7 @@ import textwrap
from jax import vmap
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, check_arraylike, promote_dtypes_inexact
from jax._src.numpy.util import implements, check_arraylike, promote_dtypes_inexact
_no_chkfinite_doc = textwrap.dedent("""
@ -28,7 +28,7 @@ because compiled JAX code cannot perform checks of array values at runtime
""")
@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
@implements(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
def vq(obs, code_book, check_finite=True):
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
if obs.ndim != code_book.ndim:

View File

@ -22,7 +22,7 @@ import scipy.fft as osp_fft
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _wraps, promote_dtypes_complex
from jax._src.numpy.util import implements, promote_dtypes_complex
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
@ -42,7 +42,7 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
@_wraps(osp_fft.dct)
@implements(osp_fft.dct)
def dct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
if type != 2:
@ -81,7 +81,7 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
return out
@_wraps(osp_fft.dctn)
@implements(osp_fft.dctn)
def dctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
@ -109,7 +109,7 @@ def dctn(x: Array, type: int = 2,
return x
@_wraps(osp_fft.dct)
@implements(osp_fft.dct)
def idct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
if type != 2:
@ -139,7 +139,7 @@ def idct(x: Array, type: int = 2, n: int | None = None,
out = _dct_deinterleave(x.real, axis)
return out
@_wraps(osp_fft.idctn)
@implements(osp_fft.idctn)
def idctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,

View File

@ -23,7 +23,7 @@ from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike
import jax.numpy as jnp
@util._wraps(scipy.integrate.trapezoid)
@util.implements(scipy.integrate.trapezoid)
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:

View File

@ -29,7 +29,7 @@ from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.util import (
check_arraylike, _wraps, promote_dtypes, promote_dtypes_inexact,
check_arraylike, implements, promote_dtypes, promote_dtypes_inexact,
promote_dtypes_complex)
from jax._src.typing import Array, ArrayLike
@ -46,14 +46,14 @@ def _cholesky(a: ArrayLike, lower: bool) -> Array:
l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False)
return l if lower else jnp.conj(l.mT)
@_wraps(scipy.linalg.cholesky,
@implements(scipy.linalg.cholesky,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Array:
del overwrite_a, check_finite # Unused
return _cholesky(a, lower)
@_wraps(scipy.linalg.cho_factor,
@implements(scipy.linalg.cho_factor,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, bool]:
@ -70,7 +70,7 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:
transpose_a=lower, conjugate_a=lower)
return b
@_wraps(scipy.linalg.cho_solve, update_doc=False,
@implements(scipy.linalg.cho_solve, update_doc=False,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike,
overwrite_b: bool = False, check_finite: bool = True) -> Array:
@ -112,7 +112,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ...
@_wraps(scipy.linalg.svd,
@implements(scipy.linalg.svd,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
@ -120,7 +120,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
del overwrite_a, check_finite, lapack_driver # unused
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@_wraps(scipy.linalg.det,
@implements(scipy.linalg.det,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
del overwrite_a, check_finite # unused
@ -182,7 +182,7 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ...
@_wraps(scipy.linalg.eigh,
@implements(scipy.linalg.eigh,
lax_description=_no_overwrite_and_chkfinite_doc,
skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite'))
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
@ -198,21 +198,21 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]:
a = a.astype(dtypes.to_complex_dtype(a.dtype))
return lax_linalg.schur(a)
@_wraps(scipy.linalg.schur)
@implements(scipy.linalg.schur)
def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
return _schur(a, output)
@_wraps(scipy.linalg.inv,
@implements(scipy.linalg.inv,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
del overwrite_a, check_finite # unused
return jnp.linalg.inv(a)
@_wraps(scipy.linalg.lu_factor,
@implements(scipy.linalg.lu_factor,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]:
@ -222,7 +222,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True
return lu, pivots
@_wraps(scipy.linalg.lu_solve,
@implements(scipy.linalg.lu_solve,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
@ -269,7 +269,7 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False,
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
@_wraps(scipy.linalg.lu, update_doc=False,
@implements(scipy.linalg.lu, update_doc=False,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
@ -320,7 +320,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ...
@_wraps(scipy.linalg.qr,
@implements(scipy.linalg.qr,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork'))
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]:
@ -352,7 +352,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
@_wraps(scipy.linalg.solve,
@implements(scipy.linalg.solve,
lax_description=_no_overwrite_and_chkfinite_doc,
skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
@ -391,7 +391,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str,
else:
return out
@_wraps(scipy.linalg.solve_triangular,
@implements(scipy.linalg.solve_triangular,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite'))
def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False,
unit_diagonal: bool = False, overwrite_b: bool = False,
@ -414,7 +414,7 @@ where norm() denotes the L1 norm, and
- c=1.97 for float32 or complex64
""")
@_wraps(scipy.linalg.expm, lax_description=_expm_description)
@implements(scipy.linalg.expm, lax_description=_expm_description)
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
A, = promote_dtypes_inexact(A)
@ -572,7 +572,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: bool = True) -> Array | tuple[Array, Array]: ...
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
@implements(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
@partial(jit, static_argnames=('method', 'compute_expm'))
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
compute_expm: bool = True) -> Array | tuple[Array, Array]:
@ -597,7 +597,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
return expm_frechet_AE
@_wraps(scipy.linalg.block_diag)
@implements(scipy.linalg.block_diag)
@jit
def block_diag(*arrs: ArrayLike) -> Array:
if len(arrs) == 0:
@ -619,7 +619,7 @@ def block_diag(*arrs: ArrayLike) -> Array:
return acc
@_wraps(scipy.linalg.eigh_tridiagonal)
@implements(scipy.linalg.eigh_tridiagonal)
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
select: str = 'a', select_range: tuple[float, float] | None = None,
@ -901,7 +901,7 @@ def _sqrtm(A: ArrayLike) -> Array:
return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST),
jnp.conj(Z.T), precision=lax.Precision.HIGHEST)
@_wraps(scipy.linalg.sqrtm,
@implements(scipy.linalg.sqrtm,
lax_description="""
This differs from ``scipy.linalg.sqrtm`` in that the return type of
``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input,
@ -918,7 +918,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:
raise NotImplementedError("Blocked version is not implemented yet.")
return _sqrtm(A)
@_wraps(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
@implements(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
@partial(jit, static_argnames=('check_finite',))
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
del check_finite # unused
@ -987,7 +987,7 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
check_finite: bool = True) -> tuple[Array, Array]: ...
@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@implements(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Array | tuple[Array, Array]:
@ -1010,7 +1010,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
else:
return h
@_wraps(scipy.linalg.toeplitz)
@implements(scipy.linalg.toeplitz)
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
if r is None:
check_arraylike("toeplitz", c)

View File

@ -25,7 +25,7 @@ from jax._src import api
from jax._src import util
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip
@ -127,7 +127,7 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
return result.astype(input_arr.dtype)
@_wraps(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
@implements(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and
modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported.
Note that interpolation near boundaries differs from the scipy function,

View File

@ -34,13 +34,13 @@ from jax._src import dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, _wraps, promote_dtypes_inexact, promote_dtypes_complex)
check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex)
from jax._src.third_party.scipy import signal_helper
from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
@_wraps(osp_signal.fftconvolve)
@implements(osp_signal.fftconvolve)
def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full",
axes: Sequence[int] | None = None) -> Array:
check_arraylike('fftconvolve', in1, in2)
@ -133,7 +133,7 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike)
return result[0, 0]
@_wraps(osp_signal.convolve)
@implements(osp_signal.convolve)
def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
precision: PrecisionLike = None) -> Array:
if method == 'fft':
@ -144,7 +144,7 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.")
@_wraps(osp_signal.convolve2d)
@implements(osp_signal.convolve2d)
def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
if boundary != 'fill' or fillvalue != 0:
@ -154,13 +154,13 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill
return _convolve_nd(in1, in2, mode, precision=precision)
@_wraps(osp_signal.correlate)
@implements(osp_signal.correlate)
def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
precision: PrecisionLike = None) -> Array:
return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method)
@_wraps(osp_signal.correlate2d)
@implements(osp_signal.correlate2d)
def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
if boundary != 'fill' or fillvalue != 0:
@ -191,7 +191,7 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil
return result
@_wraps(osp_signal.detrend)
@implements(osp_signal.detrend)
def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0,
overwrite_data: None = None) -> Array:
if overwrite_data is not None:
@ -499,7 +499,7 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
return freqs, time, result
@_wraps(osp_signal.stft)
@implements(osp_signal.stft)
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
noverlap: int | None = None, nfft: int | None = None,
detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros',
@ -518,7 +518,7 @@ to follow the latter behavior. For using the former behavior, call this
function as `csd(x, None)`."""
@_wraps(osp_signal.csd, lax_description=_csd_description)
@implements(osp_signal.csd, lax_description=_csd_description)
def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, detrend: str = 'constant',
@ -551,7 +551,7 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann'
return freqs, Pxy
@_wraps(osp_signal.welch)
@implements(osp_signal.welch)
def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, detrend: str = 'constant',
@ -613,7 +613,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
return x.reshape(tuple(batch_shape) + (-1,))
@_wraps(osp_signal.istft)
@implements(osp_signal.istft)
def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, input_onesided: bool = True,

View File

@ -22,10 +22,10 @@ import scipy.spatial.transform
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
@_wraps(scipy.spatial.transform.Rotation)
@implements(scipy.spatial.transform.Rotation)
class Rotation(typing.NamedTuple):
"""Rotation in 3 dimensions."""
@ -169,7 +169,7 @@ class Rotation(typing.NamedTuple):
return self.quat.ndim == 1
@_wraps(scipy.spatial.transform.Slerp)
@implements(scipy.spatial.transform.Slerp)
class Slerp(typing.NamedTuple):
"""Spherical Linear Interpolation of Rotations."""

View File

@ -32,19 +32,19 @@ from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
from jax._src.typing import Array, ArrayLike
@_wraps(osp_special.gammaln, module='scipy.special')
@implements(osp_special.gammaln, module='scipy.special')
def gammaln(x: ArrayLike) -> Array:
x, = promote_args_inexact("gammaln", x)
return lax.lgamma(x)
@_wraps(osp_special.gamma, module='scipy.special', lax_description="""\
@implements(osp_special.gamma, module='scipy.special', lax_description="""\
The JAX version only accepts real-valued inputs.""")
def gamma(x: ArrayLike) -> Array:
x, = promote_args_inexact("gamma", x)
@ -53,14 +53,14 @@ def gamma(x: ArrayLike) -> Array:
sign = jnp.where((x > 0) | (x == floor_x), 1.0, (-1.0) ** floor_x)
return sign * lax.exp(lax.lgamma(x))
betaln = _wraps(
betaln = implements(
osp_special.betaln,
module='scipy.special',
update_doc=False
)(_betaln_impl)
@_wraps(osp_special.factorial, module='scipy.special')
@implements(osp_special.factorial, module='scipy.special')
def factorial(n: ArrayLike, exact: bool = False) -> Array:
if exact:
raise NotImplementedError("factorial with exact=True")
@ -68,58 +68,58 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array:
return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1)))
@_wraps(osp_special.beta, module='scipy.special')
@implements(osp_special.beta, module='scipy.special')
def beta(x: ArrayLike, y: ArrayLike) -> Array:
x, y = promote_args_inexact("beta", x, y)
return lax.exp(betaln(x, y))
@_wraps(osp_special.betainc, module='scipy.special')
@implements(osp_special.betainc, module='scipy.special')
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
a, b, x = promote_args_inexact("betainc", a, b, x)
return lax.betainc(a, b, x)
@_wraps(osp_special.digamma, module='scipy.special', lax_description="""\
@implements(osp_special.digamma, module='scipy.special', lax_description="""\
The JAX version only accepts real-valued inputs.""")
def digamma(x: ArrayLike) -> Array:
x, = promote_args_inexact("digamma", x)
return lax.digamma(x)
@_wraps(osp_special.gammainc, module='scipy.special', update_doc=False)
@implements(osp_special.gammainc, module='scipy.special', update_doc=False)
def gammainc(a: ArrayLike, x: ArrayLike) -> Array:
a, x = promote_args_inexact("gammainc", a, x)
return lax.igamma(a, x)
@_wraps(osp_special.gammaincc, module='scipy.special', update_doc=False)
@implements(osp_special.gammaincc, module='scipy.special', update_doc=False)
def gammaincc(a: ArrayLike, x: ArrayLike) -> Array:
a, x = promote_args_inexact("gammaincc", a, x)
return lax.igammac(a, x)
@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"],
@implements(osp_special.erf, module='scipy.special', skip_params=["out"],
lax_description="Note that the JAX version does not support complex inputs.")
def erf(x: ArrayLike) -> Array:
x, = promote_args_inexact("erf", x)
return lax.erf(x)
@_wraps(osp_special.erfc, module='scipy.special', update_doc=False)
@implements(osp_special.erfc, module='scipy.special', update_doc=False)
def erfc(x: ArrayLike) -> Array:
x, = promote_args_inexact("erfc", x)
return lax.erfc(x)
@_wraps(osp_special.erfinv, module='scipy.special')
@implements(osp_special.erfinv, module='scipy.special')
def erfinv(x: ArrayLike) -> Array:
x, = promote_args_inexact("erfinv", x)
return lax.erf_inv(x)
@custom_derivatives.custom_jvp
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
@implements(osp_special.logit, module='scipy.special', update_doc=False)
def logit(x: ArrayLike) -> Array:
x, = promote_args_inexact("logit", x)
return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x)))
@ -127,17 +127,17 @@ logit.defjvps(
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
@implements(osp_special.expit, module='scipy.special', update_doc=False)
def expit(x: ArrayLike) -> Array:
x, = promote_args_inexact("expit", x)
return lax.logistic(x)
logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
logsumexp = implements(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
@custom_derivatives.custom_jvp
@_wraps(osp_special.xlogy, module='scipy.special')
@implements(osp_special.xlogy, module='scipy.special')
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
# Note: xlogy(0, 0) should return 0 according to the function documentation.
x, y = promote_args_inexact("xlogy", x, y)
@ -153,7 +153,7 @@ xlogy.defjvp(_xlogy_jvp)
@custom_derivatives.custom_jvp
@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
@implements(osp_special.xlog1py, module='scipy.special', update_doc=False)
def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
# Note: xlog1py(0, -1) should return 0 according to the function documentation.
x, y = promote_args_inexact("xlog1py", x, y)
@ -179,14 +179,14 @@ def _xlogx_jvp(primals, tangents):
_xlogx.defjvp(_xlogx_jvp)
@_wraps(osp_special.entr, module='scipy.special')
@implements(osp_special.entr, module='scipy.special')
def entr(x: ArrayLike) -> Array:
x, = promote_args_inexact("entr", x)
return lax.select(lax.lt(x, _lax_const(x, 0)),
lax.full_like(x, -np.inf),
lax.neg(_xlogx(x)))
@_wraps(osp_special.multigammaln, update_doc=False)
@implements(osp_special.multigammaln, update_doc=False)
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d_ = promote_args_inexact("multigammaln", a, d)
@ -201,7 +201,7 @@ def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
return res + constant
@_wraps(osp_special.kl_div, module="scipy.special")
@implements(osp_special.kl_div, module="scipy.special")
def kl_div(
p: ArrayLike,
q: ArrayLike,
@ -227,7 +227,7 @@ def kl_div(
return result
@_wraps(osp_special.rel_entr, module="scipy.special")
@implements(osp_special.rel_entr, module="scipy.special")
def rel_entr(
p: ArrayLike,
q: ArrayLike,
@ -268,7 +268,7 @@ _BERNOULLI_COEFS = [
@custom_derivatives.custom_jvp
@_wraps(osp_special.zeta, module='scipy.special')
@implements(osp_special.zeta, module='scipy.special')
def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array:
if q is None:
raise NotImplementedError(
@ -311,7 +311,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type]
@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False)
@implements(osp_special.polygamma, module='scipy.special', update_doc=False)
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
assert jnp.issubdtype(lax.dtype(n), jnp.integer)
n_arr, x_arr = promote_args_inexact("polygamma", n, x)
@ -725,22 +725,22 @@ def _norm_logpdf(x):
log_normalizer = _lax_const(x, _norm_logpdf_constant)
return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
@_wraps(osp_special.i0e, module='scipy.special')
@implements(osp_special.i0e, module='scipy.special')
def i0e(x: ArrayLike) -> Array:
x, = promote_args_inexact("i0e", x)
return lax.bessel_i0e(x)
@_wraps(osp_special.i0, module='scipy.special')
@implements(osp_special.i0, module='scipy.special')
def i0(x: ArrayLike) -> Array:
x, = promote_args_inexact("i0", x)
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x))
@_wraps(osp_special.i1e, module='scipy.special')
@implements(osp_special.i1e, module='scipy.special')
def i1e(x: ArrayLike) -> Array:
x, = promote_args_inexact("i1e", x)
return lax.bessel_i1e(x)
@_wraps(osp_special.i1, module='scipy.special')
@implements(osp_special.i1, module='scipy.special')
def i1(x: ArrayLike) -> Array:
x, = promote_args_inexact("i1", x)
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
@ -1459,7 +1459,7 @@ def _expi_neg(x: Array) -> Array:
@custom_derivatives.custom_jvp
@jit
@_wraps(osp_special.expi, module='scipy.special')
@implements(osp_special.expi, module='scipy.special')
def expi(x: ArrayLike) -> Array:
x_arr, = promote_args_inexact("expi", x)
return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos])
@ -1577,7 +1577,7 @@ def _expn3(n: int, x: Array) -> Array:
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,))
@jnp.vectorize
@_wraps(osp_special.expn, module='scipy.special')
@implements(osp_special.expn, module='scipy.special')
@jit
def expn(n: ArrayLike, x: ArrayLike) -> Array:
n, x = promote_args_inexact("expn", n, x)
@ -1615,7 +1615,7 @@ def expn_jvp(n, primals, tangents):
)
@_wraps(osp_special.exp1, module="scipy.special")
@implements(osp_special.exp1, module="scipy.special")
def exp1(x: ArrayLike, module='scipy.special') -> Array:
x, = promote_args_inexact("exp1", x)
# Casting because custom_jvp generic does not work correctly with mypy.
@ -1716,7 +1716,7 @@ def spence(x: Array) -> Array:
return _spence(x)
@_wraps(osp_special.bernoulli, module='scipy.special')
@implements(osp_special.bernoulli, module='scipy.special')
def bernoulli(n: int) -> Array:
# Generate Bernoulli numbers using the Chowla and Hartung algorithm.
n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli")
@ -1734,7 +1734,7 @@ def bernoulli(n: int) -> Array:
@custom_derivatives.custom_jvp
@_wraps(osp_special.poch, module='scipy.special', lax_description="""\
@implements(osp_special.poch, module='scipy.special', lax_description="""\
The JAX version only accepts positive and real inputs.""")
def poch(z: ArrayLike, m: ArrayLike) -> Array:
# Factorial definition when m is close to an integer, otherwise gamma definition.
@ -1883,7 +1883,7 @@ def _hyp1f1_x_derivative(a, b, x):
@custom_derivatives.custom_jvp
@jit
@jnp.vectorize
@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\
@implements(osp_special.hyp1f1, module='scipy.special', lax_description="""\
The JAX version only accepts positive and real inputs. Values of a, b and x
leading to high values of 1F1 might be erroneous, considering enabling double
precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""")

View File

@ -23,7 +23,7 @@ import jax.numpy as jnp
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact
from jax._src.numpy.util import check_arraylike, implements, promote_args_inexact
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis
@ -31,7 +31,7 @@ import scipy
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
@_wraps(scipy.stats.mode, lax_description="""\
@implements(scipy.stats.mode, lax_description="""\
Currently the only supported nan_policy is 'propagate'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
@ -90,7 +90,7 @@ def invert_permutation(i: Array) -> Array:
"""Helper function that inverts a permutation array."""
return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype))
@_wraps(scipy.stats.rankdata, lax_description="""\
@implements(scipy.stats.rankdata, lax_description="""\
Currently the only supported nan_policy is 'propagate'
""")
@partial(jit, static_argnames=["method", "axis", "nan_policy"])
@ -148,7 +148,7 @@ def rankdata(
return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_))
raise ValueError(f"unknown method '{method}'")
@_wraps(scipy.stats.sem, lax_description="""\
@implements(scipy.stats.sem, lax_description="""\
Currently the only supported nan_policies are 'propagate' and 'omit'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, xlog1py
@_wraps(osp_stats.bernoulli.logpmf, update_doc=False)
@implements(osp_stats.bernoulli.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
zero = _lax_const(k, 0)
@ -33,11 +33,11 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
-jnp.inf, log_probs)
@_wraps(osp_stats.bernoulli.pmf, update_doc=False)
@implements(osp_stats.bernoulli.pmf, update_doc=False)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.exp(logpmf(k, p, loc))
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
@implements(osp_stats.bernoulli.cdf, update_doc=False)
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
k, p = promote_args_inexact('bernoulli.cdf', k, p)
zero, one = _lax_const(k, 0), _lax_const(k, 1)
@ -50,7 +50,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array:
vals = [jnp.nan, zero, one - p, one]
return jnp.select(conds, vals)
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
@implements(osp_stats.bernoulli.ppf, update_doc=False)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
q, p = promote_args_inexact('bernoulli.ppf', q, p)
zero, one = _lax_const(q, 0), _lax_const(q, 1)

View File

@ -17,12 +17,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
@_wraps(osp_stats.beta.logpdf, update_doc=False)
@implements(osp_stats.beta.logpdf, update_doc=False)
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
@ -36,13 +36,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
lax.lt(x, loc)), -jnp.inf, log_probs)
@_wraps(osp_stats.beta.pdf, update_doc=False)
@implements(osp_stats.beta.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.cdf, update_doc=False)
@implements(osp_stats.beta.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
@ -57,13 +57,13 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
)
@_wraps(osp_stats.beta.logcdf, update_doc=False)
@implements(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.sf, update_doc=False)
@implements(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
@ -78,7 +78,7 @@ def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
)
@_wraps(osp_stats.beta.logsf, update_doc=False)
@implements(osp_stats.beta.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, b, loc, scale))

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.scipy.special import betaln
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.betabinom.logpmf, update_doc=False)
@implements(osp_stats.betabinom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.betabinom.logpmf."""
@ -40,7 +40,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
return jnp.where(n_a_b_cond, jnp.nan, log_probs)
@_wraps(osp_stats.betabinom.pmf, update_doc=False)
@implements(osp_stats.betabinom.pmf, update_doc=False)
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.betabinom.pmf."""

View File

@ -17,12 +17,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
@implements(osp_stats.nbinom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.binom.logpmf."""
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
@ -36,7 +36,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf)
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
@implements(osp_stats.nbinom.pmf, update_doc=False)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.binom.pmf."""
return lax.exp(logpmf(k, n, p, loc))

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax.numpy import arctan
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)
@implements(osp_stats.cauchy.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
pi = _lax_const(x, np.pi)
@ -32,13 +32,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
@implements(osp_stats.cauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
@implements(osp_stats.cauchy.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
pi = _lax_const(x, np.pi)
@ -46,24 +46,24 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
@implements(osp_stats.cauchy.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, loc, scale))
@_wraps(osp_stats.cauchy.sf, update_doc=False)
@implements(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
return cdf(-x, -loc, scale)
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
@implements(osp_stats.cauchy.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
@_wraps(osp_stats.cauchy.isf, update_doc=False)
@implements(osp_stats.cauchy.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
pi = _lax_const(q, np.pi)
@ -72,7 +72,7 @@ def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.add(lax.mul(unscaled, scale), loc)
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
@implements(osp_stats.cauchy.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
pi = _lax_const(q, np.pi)

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc, gammaincc
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
@implements(osp_stats.chi2.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
@ -38,12 +38,12 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.chi2.pdf, update_doc=False)
@implements(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.cdf, update_doc=False)
@implements(osp_stats.chi2.cdf, update_doc=False)
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
two = _lax_const(scale, 2)
@ -60,12 +60,12 @@ def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -
)
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
@implements(osp_stats.chi2.logcdf, update_doc=False)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.sf, update_doc=False)
@implements(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
two = _lax_const(scale, 2)
@ -82,6 +82,6 @@ def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
)
@_wraps(osp_stats.chi2.logsf, update_doc=False)
@implements(osp_stats.chi2.logsf, update_doc=False)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, df, loc, scale))

View File

@ -18,7 +18,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_dtypes_inexact, _wraps
from jax._src.numpy.util import promote_dtypes_inexact, implements
from jax.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@ -28,7 +28,7 @@ def _is_simplex(x: Array) -> Array:
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
@_wraps(osp_stats.dirichlet.logpdf, update_doc=False)
@implements(osp_stats.dirichlet.logpdf, update_doc=False)
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
return _logpdf(*promote_dtypes_inexact(x, alpha))
@ -52,6 +52,6 @@ def _logpdf(x: Array, alpha: Array) -> Array:
return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
@_wraps(osp_stats.dirichlet.pdf, update_doc=False)
@implements(osp_stats.dirichlet.pdf, update_doc=False)
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
return lax.exp(logpdf(x, alpha))

View File

@ -16,11 +16,11 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.expon.logpdf, update_doc=False)
@implements(osp_stats.expon.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
@ -28,6 +28,6 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
log_probs = lax.neg(lax.add(linear_term, log_scale))
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.expon.pdf, update_doc=False)
@implements(osp_stats.expon.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))

View File

@ -17,12 +17,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
@implements(osp_stats.gamma.logpdf, update_doc=False)
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
one = _lax_const(x, 1)
@ -32,12 +32,12 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
log_probs = lax.sub(log_linear_term, shape_terms)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.gamma.pdf, update_doc=False)
@implements(osp_stats.gamma.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.cdf, update_doc=False)
@implements(osp_stats.gamma.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
return gammainc(
@ -50,17 +50,17 @@ def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
)
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
@implements(osp_stats.gamma.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.sf, update_doc=False)
@implements(osp_stats.gamma.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
@_wraps(osp_stats.gamma.logsf, update_doc=False)
@implements(osp_stats.gamma.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, loc, scale))

View File

@ -14,20 +14,20 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
@implements(osp_stats.gennorm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, p: ArrayLike) -> Array:
x, p = promote_args_inexact("gennorm.logpdf", x, p)
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
@implements(osp_stats.gennorm.cdf, update_doc=False)
def cdf(x: ArrayLike, p: ArrayLike) -> Array:
x, p = promote_args_inexact("gennorm.cdf", x, p)
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
@implements(osp_stats.gennorm.pdf, update_doc=False)
def pdf(x: ArrayLike, p: ArrayLike) -> Array:
return lax.exp(logpdf(x, p))

View File

@ -17,12 +17,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax.scipy.special import xlog1py
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.geom.logpmf, update_doc=False)
@implements(osp_stats.geom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
zero = _lax_const(k, 0)
@ -32,6 +32,6 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.where(lax.le(x, zero), -jnp.inf, log_probs)
@_wraps(osp_stats.geom.pmf, update_doc=False)
@implements(osp_stats.geom.pmf, update_doc=False)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.exp(logpmf(k, p, loc))

View File

@ -21,12 +21,12 @@ import scipy.stats as osp_stats
import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special
@_wraps(osp_stats.gaussian_kde, update_doc=False)
@implements(osp_stats.gaussian_kde, update_doc=False)
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
@ -113,7 +113,7 @@ class gaussian_kde:
def n(self):
return self.dataset.shape[1]
@_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
@implements(osp_stats.gaussian_kde.evaluate, update_doc=False)
def evaluate(self, points):
check_arraylike("evaluate", points)
points = self._reshape_points(points)
@ -121,11 +121,11 @@ class gaussian_kde:
points.T, self.inv_cov)
return result[:, 0]
@_wraps(osp_stats.gaussian_kde.__call__, update_doc=False)
@implements(osp_stats.gaussian_kde.__call__, update_doc=False)
def __call__(self, points):
return self.evaluate(points)
@_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
@implements(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
def integrate_gaussian(self, mean, cov):
mean = jnp.atleast_1d(jnp.squeeze(mean))
cov = jnp.atleast_2d(cov)
@ -141,7 +141,7 @@ class gaussian_kde:
return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
mean)
@_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
@implements(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
def integrate_box_1d(self, low, high):
if self.d != 1:
raise ValueError("integrate_box_1d() only handles 1D pdfs")
@ -153,7 +153,7 @@ class gaussian_kde:
high = jnp.squeeze((high - self.dataset) / sigma)
return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))
@_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
@implements(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
def integrate_kde(self, other):
if other.d != self.d:
raise ValueError("KDEs are not the same dimensionality")
@ -189,11 +189,11 @@ class gaussian_kde:
dtype=self.dataset.dtype).T
return self.dataset[:, ind] + eps
@_wraps(osp_stats.gaussian_kde.pdf, update_doc=False)
@implements(osp_stats.gaussian_kde.pdf, update_doc=False)
def pdf(self, x):
return self.evaluate(x)
@_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False)
@implements(osp_stats.gaussian_kde.logpdf, update_doc=False)
def logpdf(self, x):
check_arraylike("logpdf", x)
x = self._reshape_points(x)

View File

@ -16,11 +16,11 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.laplace.logpdf, update_doc=False)
@implements(osp_stats.laplace.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
two = _lax_const(x, 2)
@ -28,12 +28,12 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
@_wraps(osp_stats.laplace.pdf, update_doc=False)
@implements(osp_stats.laplace.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.laplace.cdf, update_doc=False)
@implements(osp_stats.laplace.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
half = _lax_const(x, 0.5)

View File

@ -18,11 +18,11 @@ from jax.scipy.special import expit, logit
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
@implements(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
x = lax.div(lax.sub(x, loc), scale)
@ -31,30 +31,30 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
@_wraps(osp_stats.logistic.pdf, update_doc=False)
@implements(osp_stats.logistic.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.logistic.ppf, update_doc=False)
@implements(osp_stats.logistic.ppf, update_doc=False)
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
return lax.add(lax.mul(logit(x), scale), loc)
@_wraps(osp_stats.logistic.sf, update_doc=False)
@implements(osp_stats.logistic.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
@_wraps(osp_stats.logistic.isf, update_doc=False)
@implements(osp_stats.logistic.isf, update_doc=False)
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
@_wraps(osp_stats.logistic.cdf, update_doc=False)
@implements(osp_stats.logistic.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
return expit(lax.div(lax.sub(x, loc), scale))

View File

@ -16,12 +16,12 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact, promote_args_numeric
from jax._src.numpy.util import implements, promote_args_inexact, promote_args_numeric
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.multinomial.logpmf, update_doc=False)
@implements(osp_stats.multinomial.logpmf, update_doc=False)
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
"""JAX implementation of scipy.stats.multinomial.logpmf."""
p, = promote_args_inexact("multinomial.logpmf", p)
@ -34,7 +34,7 @@ def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -jnp.inf)
@_wraps(osp_stats.multinomial.pmf, update_doc=False)
@implements(osp_stats.multinomial.pmf, update_doc=False)
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
"""JAX implementation of scipy.stats.multinomial.pmf."""
return lax.exp(logpmf(x, n, p))

View File

@ -19,11 +19,11 @@ import scipy.stats as osp_stats
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.util import _wraps, promote_dtypes_inexact
from jax._src.numpy.util import implements, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description="""
@implements(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description="""
In the JAX version, the `allow_singular` argument is not implemented.
""")
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike:
@ -50,6 +50,6 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None =
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False)
@implements(osp_stats.multivariate_normal.pdf, update_doc=False)
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
return lax.exp(logpdf(x, mean, cov))

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
@implements(osp_stats.nbinom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.nbinom.logpmf."""
k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
@ -37,7 +37,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
@implements(osp_stats.nbinom.pmf, update_doc=False)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.nbinom.pmf."""
return lax.exp(logpmf(k, n, p, loc))

View File

@ -20,12 +20,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy import special
@_wraps(osp_stats.norm.logpdf, update_doc=False)
@implements(osp_stats.norm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
scale_sqrd = lax.square(scale)
@ -34,41 +34,41 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2))
@_wraps(osp_stats.norm.pdf, update_doc=False)
@implements(osp_stats.norm.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.norm.cdf, update_doc=False)
@implements(osp_stats.norm.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
return special.ndtr(lax.div(lax.sub(x, loc), scale))
@_wraps(osp_stats.norm.logcdf, update_doc=False)
@implements(osp_stats.norm.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
# Cast required because custom_jvp return type is broken.
return cast(Array, special.log_ndtr(lax.div(lax.sub(x, loc), scale)))
@_wraps(osp_stats.norm.ppf, update_doc=False)
@implements(osp_stats.norm.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.asarray(special.ndtri(q) * scale + loc, float)
@_wraps(osp_stats.norm.logsf, update_doc=False)
@implements(osp_stats.norm.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
@_wraps(osp_stats.norm.sf, update_doc=False)
@implements(osp_stats.norm.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
return cdf(-x, -loc, scale)
@_wraps(osp_stats.norm.isf, update_doc=False)
@implements(osp_stats.norm.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)

View File

@ -18,11 +18,11 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.pareto.logpdf, update_doc=False)
@implements(osp_stats.pareto.logpdf, update_doc=False)
def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
one = _lax_const(x, 1)
@ -31,6 +31,6 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))))
return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs)
@_wraps(osp_stats.pareto.pdf, update_doc=False)
@implements(osp_stats.pareto.pdf, update_doc=False)
def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, b, loc, scale))

View File

@ -18,12 +18,12 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, gammaln, gammaincc
@_wraps(osp_stats.poisson.logpmf, update_doc=False)
@implements(osp_stats.poisson.logpmf, update_doc=False)
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
@ -31,11 +31,11 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs)
@_wraps(osp_stats.poisson.pmf, update_doc=False)
@implements(osp_stats.poisson.pmf, update_doc=False)
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.exp(logpmf(k, mu, loc))
@_wraps(osp_stats.poisson.cdf, update_doc=False)
@implements(osp_stats.poisson.cdf, update_doc=False)
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)

View File

@ -18,11 +18,11 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.t.logpdf, update_doc=False)
@implements(osp_stats.t.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
two = _lax_const(x, 2)
@ -37,6 +37,6 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
@_wraps(osp_stats.t.pdf, update_doc=False)
@implements(osp_stats.t.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))

View File

@ -17,7 +17,7 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
@ -69,7 +69,7 @@ def _log_gauss_mass(a, b):
return out
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
@implements(osp_stats.truncnorm.logpdf, update_doc=False)
def logpdf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
@ -80,23 +80,23 @@ def logpdf(x, a, b, loc=0, scale=1):
return val
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
@implements(osp_stats.truncnorm.pdf, update_doc=False)
def pdf(x, a, b, loc=0, scale=1):
return lax.exp(logpdf(x, a, b, loc, scale))
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
@implements(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
return logcdf(-x, -b, -a, -loc, scale)
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
@implements(osp_stats.truncnorm.sf, update_doc=False)
def sf(x, a, b, loc=0, scale=1):
return lax.exp(logsf(x, a, b, loc, scale))
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
@implements(osp_stats.truncnorm.logcdf, update_doc=False)
def logcdf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
@ -113,6 +113,6 @@ def logcdf(x, a, b, loc=0, scale=1):
return logcdf
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
@implements(osp_stats.truncnorm.cdf, update_doc=False)
def cdf(x, a, b, loc=0, scale=1):
return lax.exp(logcdf(x, a, b, loc, scale))

View File

@ -19,10 +19,10 @@ from jax import lax
from jax import numpy as jnp
from jax.numpy import where, inf, logical_or
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
@_wraps(osp_stats.uniform.logpdf, update_doc=False)
@implements(osp_stats.uniform.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
log_probs = lax.neg(lax.log(scale))
@ -30,11 +30,11 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
lax.lt(x, loc)),
-inf, log_probs)
@_wraps(osp_stats.uniform.pdf, update_doc=False)
@implements(osp_stats.uniform.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.uniform.cdf, update_doc=False)
@implements(osp_stats.uniform.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
@ -43,7 +43,7 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.select(conds, vals)
@_wraps(osp_stats.uniform.ppf, update_doc=False)
@implements(osp_stats.uniform.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
return where(

View File

@ -17,15 +17,15 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.vonmises.logpdf, update_doc=False)
@implements(osp_stats.vonmises.logpdf, update_doc=False)
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
zero = _lax_const(kappa, 0)
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan)
@_wraps(osp_stats.vonmises.pdf, update_doc=False)
@implements(osp_stats.vonmises.pdf, update_doc=False)
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
return lax.exp(logpdf(x, kappa))

View File

@ -17,11 +17,11 @@ import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.numpy.util import implements, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
@implements(osp_stats.wrapcauchy.logpdf, update_doc=False)
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
return jnp.where(
@ -34,6 +34,6 @@ def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
jnp.nan,
)
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
@implements(osp_stats.wrapcauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
return lax.exp(logpdf(x, c))

View File

@ -2,7 +2,7 @@ import numpy as np
import jax.numpy as jnp
import jax.numpy.linalg as la
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, implements
def _isEmpty2d(arr):
@ -39,7 +39,7 @@ def _assert2d(*arrays):
'Array must be two-dimensional')
@_wraps(np.linalg.cond)
@implements(np.linalg.cond)
def cond(x, p=None):
check_arraylike('jnp.linalg.cond', x)
_assertNoEmpty2d(x)
@ -62,7 +62,7 @@ def cond(x, p=None):
return r
@_wraps(np.linalg.tensorinv)
@implements(np.linalg.tensorinv)
def tensorinv(a, ind=2):
check_arraylike('jnp.linalg.tensorinv', a)
a = jnp.asarray(a)
@ -79,7 +79,7 @@ def tensorinv(a, ind=2):
return ia.reshape(*invshape)
@_wraps(np.linalg.tensorsolve)
@implements(np.linalg.tensorsolve)
def tensorsolve(a, b, axes=None):
check_arraylike('jnp.linalg.tensorsolve', a, b)
a = jnp.asarray(a)
@ -108,7 +108,7 @@ def tensorsolve(a, b, axes=None):
return res
@_wraps(np.linalg.multi_dot)
@implements(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
check_arraylike('jnp.linalg.multi_dot', *arrays)
n = len(arrays)

View File

@ -4,7 +4,7 @@ import scipy.interpolate as osp_interpolate
from jax.numpy import (asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements
def _ndim_coords_from_arrays(points, ndim=None):
@ -31,7 +31,7 @@ def _ndim_coords_from_arrays(points, ndim=None):
return points
@_wraps(
@implements(
osp_interpolate.RegularGridInterpolator,
lax_description="""
In the JAX version, `bounds_error` defaults to and must always be `False` since no
@ -76,7 +76,7 @@ class RegularGridInterpolator:
self.grid = tuple(asarray(p) for p in points)
self.values = values
@_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
@implements(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
def __call__(self, xi, method=None):
method = self.method if method is None else method
if method not in ("linear", "nearest"):

View File

@ -7,7 +7,7 @@ import scipy.linalg
from jax import jit, lax
import jax.numpy as jnp
from jax._src.numpy.linalg import norm
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax._src.scipy.linalg import rsf2csf, schur
from jax._src.typing import ArrayLike, Array
@ -51,7 +51,7 @@ Additionally, unlike the SciPy implementation, when ``disp=True`` no warning
will be printed if the error in the array output is estimated to be large.
"""
@_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
@implements(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
def funm(A: ArrayLike, func: Callable[[Array], Array],
disp: bool = True) -> Array | tuple[Array, Array]:
A_arr = jnp.asarray(A)

View File

@ -52,7 +52,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements
from jax._src.util import safe_zip, NumpyComplexWarning
config.parse_flags_with_absl()
@ -5861,7 +5861,7 @@ class NumpyDocTests(jtu.JaxTestCase):
if jit:
wrapped = jax.jit(wrapped)
wrapped = _wraps(orig, skip_params=['out'])(wrapped)
wrapped = implements(orig, skip_params=['out'])(wrapped)
doc = wrapped.__doc__
self.assertStartsWith(doc, "Example Docstring")