mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rename _wraps to implements
This commit is contained in:
parent
4646c64f54
commit
43a9faa06a
@ -22,7 +22,7 @@ from jax import dtypes
|
||||
from jax import lax
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.numpy.util import check_arraylike, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, implements
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import ufuncs, reductions
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
@ -105,28 +105,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
|
||||
return transformed
|
||||
|
||||
|
||||
@_wraps(np.fft.fftn)
|
||||
@implements(np.fft.fftn)
|
||||
def fftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.ifftn)
|
||||
@implements(np.fft.ifftn)
|
||||
def ifftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.rfftn)
|
||||
@implements(np.fft.rfftn)
|
||||
def rfftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.irfftn)
|
||||
@implements(np.fft.irfftn)
|
||||
def irfftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
@ -150,31 +150,31 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType,
|
||||
return _fft_core(func_name, fft_type, a, s, axes, norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.fft)
|
||||
@implements(np.fft.fft)
|
||||
def fft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.ifft)
|
||||
@implements(np.fft.ifft)
|
||||
def ifft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.rfft)
|
||||
@implements(np.fft.rfft)
|
||||
def rfft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.irfft)
|
||||
@implements(np.fft.irfft)
|
||||
def irfft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.hfft)
|
||||
@implements(np.fft.hfft)
|
||||
def hfft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
conj_a = ufuncs.conj(a)
|
||||
@ -183,7 +183,7 @@ def hfft(a: ArrayLike, n: int | None = None,
|
||||
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis,
|
||||
norm=norm) * nn
|
||||
|
||||
@_wraps(np.fft.ihfft)
|
||||
@implements(np.fft.ihfft)
|
||||
def ihfft(a: ArrayLike, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
_axis_check_1d('ihfft', axis)
|
||||
@ -206,32 +206,32 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
|
||||
return _fft_core(func_name, fft_type, a, s, axes, norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.fft2)
|
||||
@implements(np.fft.fft2)
|
||||
def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.ifft2)
|
||||
@implements(np.fft.ifft2)
|
||||
def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.rfft2)
|
||||
@implements(np.fft.rfft2)
|
||||
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
|
||||
norm=norm)
|
||||
|
||||
@_wraps(np.fft.irfft2)
|
||||
@implements(np.fft.irfft2)
|
||||
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
||||
norm: str | None = None) -> Array:
|
||||
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
|
||||
norm=norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.fftfreq, extra_params="""
|
||||
@implements(np.fft.fftfreq, extra_params="""
|
||||
dtype : Optional
|
||||
The dtype of the returned frequencies. If not specified, JAX's default
|
||||
floating point dtype will be used.
|
||||
@ -266,7 +266,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
|
||||
return k / jnp.array(d * n, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.fft.rfftfreq, extra_params="""
|
||||
@implements(np.fft.rfftfreq, extra_params="""
|
||||
dtype : Optional
|
||||
The dtype of the returned frequencies. If not specified, JAX's default
|
||||
floating point dtype will be used.
|
||||
@ -292,7 +292,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
|
||||
return k / jnp.array(d * n, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.fft.fftshift)
|
||||
@implements(np.fft.fftshift)
|
||||
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
||||
check_arraylike("fftshift", x)
|
||||
x = jnp.asarray(x)
|
||||
@ -308,7 +308,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
||||
return jnp.roll(x, shift, axes)
|
||||
|
||||
|
||||
@_wraps(np.fft.ifftshift)
|
||||
@implements(np.fft.ifftshift)
|
||||
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
||||
check_arraylike("ifftshift", x)
|
||||
x = jnp.asarray(x)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -30,7 +30,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import reductions, ufuncs
|
||||
from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike
|
||||
from jax._src.numpy.util import implements, promote_dtypes_inexact, check_arraylike
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
|
||||
@ -63,7 +63,7 @@ def _H(x: ArrayLike) -> Array:
|
||||
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
@implements(np.linalg.cholesky)
|
||||
@jit
|
||||
def cholesky(a: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.cholesky", a)
|
||||
@ -86,7 +86,7 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
|
||||
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
hermitian: bool = False) -> Array | SVDResult: ...
|
||||
|
||||
@_wraps(np.linalg.svd)
|
||||
@implements(np.linalg.svd)
|
||||
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
|
||||
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
hermitian: bool = False) -> Array | SVDResult:
|
||||
@ -115,7 +115,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)
|
||||
|
||||
|
||||
@_wraps(np.linalg.matrix_power)
|
||||
@implements(np.linalg.matrix_power)
|
||||
@partial(jit, static_argnames=('n',))
|
||||
def matrix_power(a: ArrayLike, n: int) -> Array:
|
||||
check_arraylike("jnp.linalg.matrix_power", a)
|
||||
@ -154,7 +154,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
|
||||
return result
|
||||
|
||||
|
||||
@_wraps(np.linalg.matrix_rank)
|
||||
@implements(np.linalg.matrix_rank)
|
||||
@jit
|
||||
def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("jnp.linalg.matrix_rank", M)
|
||||
@ -211,7 +211,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
|
||||
sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
|
||||
return sign_diag * sign_taus, log_abs_det
|
||||
|
||||
@_wraps(
|
||||
@implements(
|
||||
np.linalg.slogdet,
|
||||
extra_params=textwrap.dedent("""
|
||||
method: string, optional
|
||||
@ -357,7 +357,7 @@ def _det_3x3(a: Array) -> Array:
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.linalg.det)
|
||||
@implements(np.linalg.det)
|
||||
@jit
|
||||
def det(a: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.det", a)
|
||||
@ -383,7 +383,7 @@ def _det_jvp(primals, tangents):
|
||||
return y, jnp.trace(z, axis1=-1, axis2=-2)
|
||||
|
||||
|
||||
@_wraps(np.linalg.eig, lax_description="""
|
||||
@implements(np.linalg.eig, lax_description="""
|
||||
This differs from :func:`numpy.linalg.eig` in that the return type of
|
||||
:func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input,
|
||||
and ``complex128`` for 64-bit input.
|
||||
@ -399,7 +399,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
|
||||
return w, v
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigvals)
|
||||
@implements(np.linalg.eigvals)
|
||||
@jit
|
||||
def eigvals(a: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.eigvals", a)
|
||||
@ -407,7 +407,7 @@ def eigvals(a: ArrayLike) -> Array:
|
||||
compute_right_eigenvectors=False)[0]
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigh)
|
||||
@implements(np.linalg.eigh)
|
||||
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
|
||||
def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
symmetrize_input: bool = True) -> EighResult:
|
||||
@ -425,7 +425,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
return EighResult(w, v)
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigvalsh)
|
||||
@implements(np.linalg.eigvalsh)
|
||||
@partial(jit, static_argnames=('UPLO',))
|
||||
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
check_arraylike("jnp.linalg.eigvalsh", a)
|
||||
@ -434,7 +434,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
|
||||
@implements(np.linalg.pinv, lax_description=textwrap.dedent("""\
|
||||
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
|
||||
default `rcond` is `1e-15`. Here the default is
|
||||
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
|
||||
@ -494,7 +494,7 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
return p, p_dot
|
||||
|
||||
|
||||
@_wraps(np.linalg.inv)
|
||||
@implements(np.linalg.inv)
|
||||
@jit
|
||||
def inv(a: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.inv", a)
|
||||
@ -506,7 +506,7 @@ def inv(a: ArrayLike) -> Array:
|
||||
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))
|
||||
|
||||
|
||||
@_wraps(np.linalg.norm)
|
||||
@implements(np.linalg.norm)
|
||||
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
|
||||
def norm(x: ArrayLike, ord: int | str | None = None,
|
||||
axis: None | tuple[int, ...] | int = None,
|
||||
@ -608,7 +608,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
|
||||
@overload
|
||||
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...
|
||||
|
||||
@_wraps(np.linalg.qr)
|
||||
@implements(np.linalg.qr)
|
||||
@partial(jit, static_argnames=('mode',))
|
||||
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
|
||||
check_arraylike("jnp.linalg.qr", a)
|
||||
@ -628,7 +628,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
|
||||
return QRResult(q, r)
|
||||
|
||||
|
||||
@_wraps(np.linalg.solve)
|
||||
@implements(np.linalg.solve)
|
||||
@jit
|
||||
def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.solve", a, b)
|
||||
@ -689,7 +689,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
|
||||
|
||||
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
|
||||
|
||||
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
|
||||
@implements(np.linalg.lstsq, lax_description=textwrap.dedent("""\
|
||||
It has two important differences:
|
||||
|
||||
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
|
||||
@ -710,7 +710,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
|
||||
return _jit_lstsq(a, b, rcond)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "cross", None))
|
||||
@implements(getattr(np.linalg, "cross", None))
|
||||
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
|
||||
check_arraylike("jnp.linalg.outer", x1, x2)
|
||||
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
|
||||
@ -722,7 +722,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
|
||||
return jnp.cross(x1, x2, axis=axis)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "outer", None))
|
||||
@implements(getattr(np.linalg, "outer", None))
|
||||
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("jnp.linalg.outer", x1, x2)
|
||||
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
|
||||
@ -731,7 +731,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return x1[:, None] * x2[None, :]
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "matrix_norm", None))
|
||||
@implements(getattr(np.linalg, "matrix_norm", None))
|
||||
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
|
||||
"""
|
||||
Computes the matrix norm of a matrix (or a stack of matrices) x.
|
||||
@ -740,7 +740,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') ->
|
||||
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "matrix_transpose", None))
|
||||
@implements(getattr(np.linalg, "matrix_transpose", None))
|
||||
def matrix_transpose(x: ArrayLike, /) -> Array:
|
||||
"""Transposes a matrix (or a stack of matrices) x."""
|
||||
check_arraylike('jnp.linalg.matrix_transpose', x)
|
||||
@ -751,7 +751,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
|
||||
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "vector_norm", None))
|
||||
@implements(getattr(np.linalg, "vector_norm", None))
|
||||
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
|
||||
ord: int | str = 2) -> Array:
|
||||
"""Computes the vector norm of a vector (or batch of vectors) x."""
|
||||
@ -764,31 +764,31 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
|
||||
return norm(x, axis=axis, keepdims=keepdims, ord=ord)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "vecdot", None))
|
||||
@implements(getattr(np.linalg, "vecdot", None))
|
||||
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
|
||||
return jnp.vecdot(x1, x2, axis=axis)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "matmul", None))
|
||||
@implements(getattr(np.linalg, "matmul", None))
|
||||
def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike('jnp.linalg.matmul', x1, x2)
|
||||
return jnp.matmul(x1, x2)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "tensordot", None))
|
||||
@implements(getattr(np.linalg, "tensordot", None))
|
||||
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
|
||||
axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
|
||||
check_arraylike('jnp.linalg.tensordot', x1, x2)
|
||||
return jnp.tensordot(x1, x2, axes=axes)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "svdvals", None))
|
||||
@implements(getattr(np.linalg, "svdvals", None))
|
||||
def svdvals(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('jnp.linalg.svdvals', x)
|
||||
return svd(x, compute_uv=False, hermitian=False)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "diagonal", None))
|
||||
@implements(getattr(np.linalg, "diagonal", None))
|
||||
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
|
||||
check_arraylike('jnp.linalg.diagonal', x)
|
||||
return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
|
||||
|
@ -31,7 +31,7 @@ from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
|
||||
from jax._src.numpy.reductions import all
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, _wraps)
|
||||
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, implements)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
|
||||
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
|
||||
|
||||
|
||||
@_wraps(np.roots, lax_description="""\
|
||||
@implements(np.roots, lax_description="""\
|
||||
Unlike the numpy version of this function, the JAX version returns the roots in
|
||||
a complex array regardless of the values of the roots. Additionally, the jax
|
||||
version of this function adds the ``strip_zeros`` function which must be set to
|
||||
@ -106,7 +106,7 @@ _POLYFIT_DOC = """\
|
||||
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
|
||||
Also, it works best on rcond <= 10e-3 values.
|
||||
"""
|
||||
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@implements(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
|
||||
def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
|
||||
full: bool = False, w: Array | None = None, cov: bool = False
|
||||
@ -187,7 +187,7 @@ np.poly returns an array with a real dtype in such cases.
|
||||
jax returns an array with a complex dtype in such cases.
|
||||
"""
|
||||
|
||||
@_wraps(np.poly, lax_description=_POLY_DOC)
|
||||
@implements(np.poly, lax_description=_POLY_DOC)
|
||||
@jit
|
||||
def poly(seq_of_zeros: Array) -> Array:
|
||||
check_arraylike('poly', seq_of_zeros)
|
||||
@ -214,7 +214,7 @@ def poly(seq_of_zeros: Array) -> Array:
|
||||
return a
|
||||
|
||||
|
||||
@_wraps(np.polyval, lax_description="""\
|
||||
@implements(np.polyval, lax_description="""\
|
||||
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
|
||||
have a major impact on performance for evaluating high-order polynomials. The
|
||||
parameter controls the number of unrolled steps with ``lax.scan`` inside the
|
||||
@ -231,7 +231,7 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
|
||||
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
|
||||
return y
|
||||
|
||||
@_wraps(np.polyadd)
|
||||
@implements(np.polyadd)
|
||||
@jit
|
||||
def polyadd(a1: Array, a2: Array) -> Array:
|
||||
check_arraylike("polyadd", a1, a2)
|
||||
@ -242,7 +242,7 @@ def polyadd(a1: Array, a2: Array) -> Array:
|
||||
return a2.at[-a1.shape[0]:].add(a1)
|
||||
|
||||
|
||||
@_wraps(np.polyint)
|
||||
@implements(np.polyint)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
@ -265,7 +265,7 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
|
||||
return true_divide(concatenate((p, k_arr)), coeff)
|
||||
|
||||
|
||||
@_wraps(np.polyder)
|
||||
@implements(np.polyder)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyder(p: Array, m: int = 1) -> Array:
|
||||
check_arraylike("polyder", p)
|
||||
@ -288,7 +288,7 @@ considered zero may lead to inconsistent results between NumPy and JAX, and even
|
||||
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
|
||||
"""
|
||||
|
||||
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
@implements(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
|
||||
check_arraylike("polymul", a1, a2)
|
||||
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
|
||||
@ -300,7 +300,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
|
||||
a2_arr = asarray([0], dtype=a1_arr.dtype)
|
||||
return convolve(a1_arr, a2_arr, mode='full')
|
||||
|
||||
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
@implements(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
|
||||
check_arraylike("polydiv", u, v)
|
||||
u_arr, v_arr = promote_dtypes_inexact(u, v)
|
||||
@ -317,7 +317,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
|
||||
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
|
||||
return q, u_arr
|
||||
|
||||
@_wraps(np.polysub)
|
||||
@implements(np.polysub)
|
||||
@jit
|
||||
def polysub(a1: Array, a2: Array) -> Array:
|
||||
check_arraylike("polysub", a1, a2)
|
||||
|
@ -31,7 +31,7 @@ from jax._src import dtypes
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy.util import (
|
||||
_broadcast_to, check_arraylike, _complex_elem_type,
|
||||
promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
|
||||
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||
from jax._src.util import (
|
||||
@ -219,7 +219,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
initial=initial, where_=where, parallel_reduce=lax.psum,
|
||||
promote_integers=promote_integers)
|
||||
|
||||
@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
|
||||
@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
|
||||
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
|
||||
@ -238,7 +238,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, promote_integers=promote_integers)
|
||||
|
||||
@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
|
||||
@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
|
||||
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None,
|
||||
@ -256,7 +256,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmax)
|
||||
|
||||
@_wraps(np.max, skip_params=['out'])
|
||||
@implements(np.max, skip_params=['out'])
|
||||
def max(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
@ -271,7 +271,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmin)
|
||||
|
||||
@_wraps(np.min, skip_params=['out'])
|
||||
@implements(np.min, skip_params=['out'])
|
||||
def min(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
@ -284,7 +284,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
|
||||
axis=axis, out=out, keepdims=keepdims, where_=where)
|
||||
|
||||
@_wraps(np.all, skip_params=['out'])
|
||||
@implements(np.all, skip_params=['out'])
|
||||
def all(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
|
||||
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
|
||||
@ -296,7 +296,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
|
||||
axis=axis, out=out, keepdims=keepdims, where_=where)
|
||||
|
||||
@_wraps(np.any, skip_params=['out'])
|
||||
@implements(np.any, skip_params=['out'])
|
||||
def any(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
|
||||
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
|
||||
@ -316,7 +316,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
|
||||
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
|
||||
return size
|
||||
|
||||
@_wraps(np.mean, skip_params=['out'])
|
||||
@implements(np.mean, skip_params=['out'])
|
||||
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
@ -365,7 +365,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
|
||||
@overload
|
||||
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
|
||||
@_wraps(np.average)
|
||||
@implements(np.average)
|
||||
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
|
||||
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
|
||||
@ -425,7 +425,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
return avg
|
||||
|
||||
|
||||
@_wraps(np.var, skip_params=['out'])
|
||||
@implements(np.var, skip_params=['out'])
|
||||
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
@ -486,7 +486,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
|
||||
return _upcast_f16(computation_dtype), np.dtype(dtype)
|
||||
|
||||
|
||||
@_wraps(np.std, skip_params=['out'])
|
||||
@implements(np.std, skip_params=['out'])
|
||||
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
@ -506,7 +506,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
|
||||
|
||||
|
||||
@_wraps(np.ptp, skip_params=['out'])
|
||||
@implements(np.ptp, skip_params=['out'])
|
||||
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False) -> Array:
|
||||
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
|
||||
@ -522,7 +522,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
return lax.sub(x, y)
|
||||
|
||||
|
||||
@_wraps(np.count_nonzero)
|
||||
@implements(np.count_nonzero)
|
||||
@partial(api.jit, static_argnames=('axis', 'keepdims'))
|
||||
def count_nonzero(a: ArrayLike, axis: Axis = None,
|
||||
keepdims: bool = False) -> Array:
|
||||
@ -546,7 +546,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
|
||||
else:
|
||||
return out
|
||||
|
||||
@_wraps(np.nanmin, skip_params=['out'])
|
||||
@implements(np.nanmin, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'keepdims'))
|
||||
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
@ -555,7 +555,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where=where)
|
||||
|
||||
@_wraps(np.nanmax, skip_params=['out'])
|
||||
@implements(np.nanmax, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'keepdims'))
|
||||
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
@ -564,7 +564,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where=where)
|
||||
|
||||
@_wraps(np.nansum, skip_params=['out'])
|
||||
@implements(np.nansum, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
@ -578,7 +578,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
|
||||
if nansum.__doc__ is not None:
|
||||
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
|
||||
|
||||
@_wraps(np.nanprod, skip_params=['out'])
|
||||
@implements(np.nanprod, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
@ -588,7 +588,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where=where)
|
||||
|
||||
@_wraps(np.nanmean, skip_params=['out'])
|
||||
@implements(np.nanmean, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
keepdims: bool = False, where: ArrayLike | None = None) -> Array:
|
||||
@ -608,7 +608,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
|
||||
return td
|
||||
|
||||
|
||||
@_wraps(np.nanvar, skip_params=['out'])
|
||||
@implements(np.nanvar, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
ddof: int = 0, keepdims: bool = False,
|
||||
@ -639,7 +639,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
|
||||
return lax.convert_element_type(result, dtype)
|
||||
|
||||
|
||||
@_wraps(np.nanstd, skip_params=['out'])
|
||||
@implements(np.nanstd, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
ddof: int = 0, keepdims: bool = False,
|
||||
@ -664,7 +664,7 @@ match the dtype of the input.
|
||||
|
||||
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
|
||||
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
|
||||
@_wraps(np_reduction, skip_params=['out'],
|
||||
@implements(np_reduction, skip_params=['out'],
|
||||
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
|
||||
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
@ -709,7 +709,7 @@ nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
|
||||
fill_nan=True, fill_value=1)
|
||||
|
||||
# Quantiles
|
||||
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
|
||||
@ -725,7 +725,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
|
||||
|
||||
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.nanquantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
|
||||
@ -862,7 +862,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
|
||||
result = result.reshape(keepdim)
|
||||
return lax.convert_element_type(result, a.dtype)
|
||||
|
||||
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.percentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def percentile(a: ArrayLike, q: ArrayLike,
|
||||
@ -874,7 +874,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
|
||||
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
interpolation=interpolation, method=method, keepdims=keepdims)
|
||||
|
||||
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.nanpercentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims', 'method'))
|
||||
def nanpercentile(a: ArrayLike, q: ArrayLike,
|
||||
@ -887,7 +887,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
|
||||
interpolation=interpolation, method=method,
|
||||
keepdims=keepdims)
|
||||
|
||||
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.median, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
|
||||
out: None = None, overwrite_input: bool = False,
|
||||
@ -896,7 +896,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
|
||||
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
keepdims=keepdims, method='midpoint')
|
||||
|
||||
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
|
||||
@implements(np.nanmedian, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
|
||||
out: None = None, overwrite_input: bool = False,
|
||||
|
@ -34,7 +34,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
sort, where, zeros)
|
||||
from jax._src.numpy.reductions import any, cumsum
|
||||
from jax._src.numpy.ufuncs import isnan
|
||||
from jax._src.numpy.util import check_arraylike, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, implements
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@ -61,7 +61,7 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
|
||||
else:
|
||||
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
|
||||
|
||||
@_wraps(np.setdiff1d,
|
||||
@implements(np.setdiff1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
@ -98,7 +98,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
|
||||
|
||||
|
||||
@_wraps(np.union1d,
|
||||
@implements(np.union1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``union1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
@ -125,7 +125,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
|
||||
return cast(Array, out)
|
||||
|
||||
|
||||
@_wraps(np.setxor1d, lax_description="""
|
||||
@implements(np.setxor1d, lax_description="""
|
||||
In the JAX version, the input arrays are explicitly flattened regardless
|
||||
of assume_unique value.
|
||||
""")
|
||||
@ -169,7 +169,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo
|
||||
return aux, mask
|
||||
|
||||
|
||||
@_wraps(np.intersect1d)
|
||||
@implements(np.intersect1d)
|
||||
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return_indices: bool = False) -> Array | tuple[Array, Array, Array]:
|
||||
check_arraylike("intersect1d", ar1, ar2)
|
||||
@ -206,7 +206,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return int1d
|
||||
|
||||
|
||||
@_wraps(np.isin, lax_description="""
|
||||
@implements(np.isin, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
def isin(element: ArrayLike, test_elements: ArrayLike,
|
||||
@ -312,7 +312,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
ret += (mask.sum(),)
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
@_wraps(np.unique, skip_params=['axis'],
|
||||
@implements(np.unique, skip_params=['axis'],
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``unique`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
@ -368,7 +368,7 @@ class _UniqueInverseResult(NamedTuple):
|
||||
inverse_indices: Array
|
||||
|
||||
|
||||
@_wraps(getattr(np, "unique_all", None))
|
||||
@implements(getattr(np, "unique_all", None))
|
||||
def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
|
||||
check_arraylike("unique_all", x)
|
||||
values, indices, inverse_indices, counts = unique(
|
||||
@ -376,21 +376,21 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
|
||||
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
|
||||
|
||||
|
||||
@_wraps(getattr(np, "unique_counts", None))
|
||||
@implements(getattr(np, "unique_counts", None))
|
||||
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
|
||||
check_arraylike("unique_counts", x)
|
||||
values, counts = unique(x, return_counts=True, equal_nan=False)
|
||||
return _UniqueCountsResult(values=values, counts=counts)
|
||||
|
||||
|
||||
@_wraps(getattr(np, "unique_inverse", None))
|
||||
@implements(getattr(np, "unique_inverse", None))
|
||||
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
|
||||
check_arraylike("unique_inverse", x)
|
||||
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
|
||||
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
|
||||
|
||||
|
||||
@_wraps(getattr(np, "unique_values", None))
|
||||
@implements(getattr(np, "unique_values", None))
|
||||
def unique_values(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("unique_values", x)
|
||||
return cast(Array, unique(x, equal_nan=False))
|
||||
|
@ -27,7 +27,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
|
||||
from jax._src.numpy.reductions import _moveaxis
|
||||
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
|
||||
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.util import canonicalize_axis, set_module
|
||||
import numpy as np
|
||||
@ -131,7 +131,7 @@ class ufunc:
|
||||
raise NotImplementedError(f"where argument of {self}")
|
||||
return self._call(*args, **kwargs)
|
||||
|
||||
@_wraps(np.ufunc.reduce, module="numpy.ufunc")
|
||||
@implements(np.ufunc.reduce, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
|
||||
def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
@ -219,7 +219,7 @@ class ufunc:
|
||||
result = result.reshape(final_shape)
|
||||
return result
|
||||
|
||||
@_wraps(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@implements(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
|
||||
def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
out: None = None) -> Array:
|
||||
@ -257,7 +257,7 @@ class ufunc:
|
||||
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
|
||||
return _moveaxis(result, 0, axis)
|
||||
|
||||
@_wraps(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@implements(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
|
||||
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
|
||||
inplace: bool = True) -> Array:
|
||||
@ -296,7 +296,7 @@ class ufunc:
|
||||
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
|
||||
return carry[1]
|
||||
|
||||
@_wraps(np.ufunc.reduceat, module="numpy.ufunc")
|
||||
@implements(np.ufunc.reduceat, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
|
||||
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
@ -335,7 +335,7 @@ class ufunc:
|
||||
out)
|
||||
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
|
||||
|
||||
@_wraps(np.ufunc.outer, module="numpy.ufunc")
|
||||
@implements(np.ufunc.outer, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnums=[0])
|
||||
def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array:
|
||||
if self.nin != 2:
|
||||
|
@ -34,7 +34,7 @@ from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_args, promote_args_inexact,
|
||||
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
|
||||
promote_shapes, _where, _wraps, check_no_float0s)
|
||||
promote_shapes, _where, implements, check_no_float0s)
|
||||
|
||||
_lax_const = lax._const
|
||||
|
||||
@ -68,9 +68,9 @@ def _one_to_one_unop(
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
|
||||
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn, module='numpy')(fn)
|
||||
return implements(numpy_fn, module='numpy')(fn)
|
||||
|
||||
|
||||
def _one_to_one_binop(
|
||||
@ -87,9 +87,9 @@ def _one_to_one_binop(
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
|
||||
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn, module='numpy')(fn)
|
||||
return implements(numpy_fn, module='numpy')(fn)
|
||||
|
||||
|
||||
def _maybe_bool_binop(
|
||||
@ -102,9 +102,9 @@ def _maybe_bool_binop(
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr]
|
||||
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
return implements(numpy_fn, lax_description=doc, module='numpy')(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn, module='numpy')(fn)
|
||||
return implements(numpy_fn, module='numpy')(fn)
|
||||
|
||||
|
||||
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
|
||||
@ -120,7 +120,7 @@ def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
|
||||
return lax_fn(x1, x2)
|
||||
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
|
||||
fn = jit(fn, inline=True)
|
||||
return _wraps(numpy_fn, module='numpy')(fn)
|
||||
return implements(numpy_fn, module='numpy')(fn)
|
||||
|
||||
@overload
|
||||
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ...
|
||||
@ -130,7 +130,7 @@ def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ...
|
||||
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ...
|
||||
|
||||
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp:
|
||||
@_wraps(np_op, update_doc=False, module='numpy')
|
||||
@implements(np_op, update_doc=False, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def op(*args):
|
||||
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
|
||||
@ -214,14 +214,14 @@ atanh = _one_to_one_unop(getattr(np, "atanh", np.arctanh), lax.atanh, True)
|
||||
atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True)
|
||||
|
||||
|
||||
@_wraps(getattr(np, 'bitwise_count', None), module='numpy')
|
||||
@implements(getattr(np, 'bitwise_count', None), module='numpy')
|
||||
@jit
|
||||
def bitwise_count(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_numeric("bitwise_count", x)
|
||||
# Following numpy we take the absolute value and return uint8.
|
||||
return lax.population_count(abs(x)).astype('uint8')
|
||||
|
||||
@_wraps(np.right_shift, module='numpy')
|
||||
@implements(np.right_shift, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
|
||||
@ -229,7 +229,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
|
||||
return lax_fn(x1, x2)
|
||||
|
||||
@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
|
||||
@implements(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2)
|
||||
@ -237,16 +237,16 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
|
||||
return lax_fn(x1, x2)
|
||||
|
||||
@_wraps(np.absolute, module='numpy')
|
||||
@implements(np.absolute, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def absolute(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('absolute', x)
|
||||
dt = dtypes.dtype(x)
|
||||
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
|
||||
abs = _wraps(np.abs, module='numpy')(absolute)
|
||||
abs = implements(np.abs, module='numpy')(absolute)
|
||||
|
||||
|
||||
@_wraps(np.rint, module='numpy')
|
||||
@implements(np.rint, module='numpy')
|
||||
@jit
|
||||
def rint(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('rint', x)
|
||||
@ -258,7 +258,7 @@ def rint(x: ArrayLike, /) -> Array:
|
||||
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
|
||||
|
||||
|
||||
@_wraps(np.copysign, module='numpy')
|
||||
@implements(np.copysign, module='numpy')
|
||||
@jit
|
||||
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_inexact("copysign", x1, x2)
|
||||
@ -267,7 +267,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
|
||||
|
||||
|
||||
@_wraps(np.true_divide, module='numpy')
|
||||
@implements(np.true_divide, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_inexact("true_divide", x1, x2)
|
||||
@ -276,7 +276,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
divide = true_divide
|
||||
|
||||
|
||||
@_wraps(np.floor_divide, module='numpy')
|
||||
@implements(np.floor_divide, module='numpy')
|
||||
@jit
|
||||
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_numeric("floor_divide", x1, x2)
|
||||
@ -301,7 +301,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return _float_divmod(x1, x2)[0]
|
||||
|
||||
|
||||
@_wraps(np.divmod, module='numpy')
|
||||
@implements(np.divmod, module='numpy')
|
||||
@jit
|
||||
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
|
||||
x1, x2 = promote_args_numeric("divmod", x1, x2)
|
||||
@ -323,7 +323,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
|
||||
return lax.round(div), mod
|
||||
|
||||
|
||||
@_wraps(np.power, module='numpy')
|
||||
@implements(np.power, module='numpy')
|
||||
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("power", x1, x2)
|
||||
check_no_float0s("power", x1, x2)
|
||||
@ -393,7 +393,7 @@ def _pow_int_int(x1, x2):
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp, module='numpy')
|
||||
@implements(np.logaddexp, module='numpy')
|
||||
@jit
|
||||
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_inexact("logaddexp", x1, x2)
|
||||
@ -431,7 +431,7 @@ def _logaddexp_jvp(primals, tangents):
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp2, module='numpy')
|
||||
@implements(np.logaddexp2, module='numpy')
|
||||
@jit
|
||||
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
|
||||
@ -459,28 +459,28 @@ def _logaddexp2_jvp(primals, tangents):
|
||||
return primal_out, tangent_out
|
||||
|
||||
|
||||
@_wraps(np.log2, module='numpy')
|
||||
@implements(np.log2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def log2(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_inexact("log2", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
|
||||
|
||||
|
||||
@_wraps(np.log10, module='numpy')
|
||||
@implements(np.log10, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def log10(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_inexact("log10", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
|
||||
|
||||
|
||||
@_wraps(np.exp2, module='numpy')
|
||||
@implements(np.exp2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def exp2(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_inexact("exp2", x)
|
||||
return lax.exp2(x)
|
||||
|
||||
|
||||
@_wraps(np.signbit, module='numpy')
|
||||
@implements(np.signbit, module='numpy')
|
||||
@jit
|
||||
def signbit(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args("signbit", x)
|
||||
@ -511,7 +511,7 @@ def _normalize_float(x):
|
||||
return lax.bitcast_convert_type(x1, int_type), x2
|
||||
|
||||
|
||||
@_wraps(np.ldexp, module='numpy')
|
||||
@implements(np.ldexp, module='numpy')
|
||||
@jit
|
||||
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("ldexp", x1, x2)
|
||||
@ -560,7 +560,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
|
||||
|
||||
|
||||
@_wraps(np.frexp, module='numpy')
|
||||
@implements(np.frexp, module='numpy')
|
||||
@jit
|
||||
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
|
||||
check_arraylike("frexp", x)
|
||||
@ -584,7 +584,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
|
||||
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
|
||||
|
||||
|
||||
@_wraps(np.remainder, module='numpy')
|
||||
@implements(np.remainder, module='numpy')
|
||||
@jit
|
||||
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = promote_args_numeric("remainder", x1, x2)
|
||||
@ -596,10 +596,10 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
do_plus = lax.bitwise_and(
|
||||
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
|
||||
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
|
||||
mod = _wraps(np.mod, module='numpy')(remainder)
|
||||
mod = implements(np.mod, module='numpy')(remainder)
|
||||
|
||||
|
||||
@_wraps(np.fmod, module='numpy')
|
||||
@implements(np.fmod, module='numpy')
|
||||
@jit
|
||||
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("fmod", x1, x2)
|
||||
@ -608,7 +608,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return lax.rem(*promote_args_numeric("fmod", x1, x2))
|
||||
|
||||
|
||||
@_wraps(np.square, module='numpy')
|
||||
@implements(np.square, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def square(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("square", x)
|
||||
@ -616,14 +616,14 @@ def square(x: ArrayLike, /) -> Array:
|
||||
return lax.integer_pow(x, 2)
|
||||
|
||||
|
||||
@_wraps(np.deg2rad, module='numpy')
|
||||
@implements(np.deg2rad, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def deg2rad(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_inexact("deg2rad", x)
|
||||
return lax.mul(x, _lax_const(x, np.pi / 180))
|
||||
|
||||
|
||||
@_wraps(np.rad2deg, module='numpy')
|
||||
@implements(np.rad2deg, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def rad2deg(x: ArrayLike, /) -> Array:
|
||||
x, = promote_args_inexact("rad2deg", x)
|
||||
@ -634,7 +634,7 @@ degrees = rad2deg
|
||||
radians = deg2rad
|
||||
|
||||
|
||||
@_wraps(np.conjugate, module='numpy')
|
||||
@implements(np.conjugate, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def conjugate(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("conjugate", x)
|
||||
@ -642,20 +642,20 @@ def conjugate(x: ArrayLike, /) -> Array:
|
||||
conj = conjugate
|
||||
|
||||
|
||||
@_wraps(np.imag)
|
||||
@implements(np.imag)
|
||||
@partial(jit, inline=True)
|
||||
def imag(val: ArrayLike, /) -> Array:
|
||||
check_arraylike("imag", val)
|
||||
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
|
||||
|
||||
|
||||
@_wraps(np.real)
|
||||
@implements(np.real)
|
||||
@partial(jit, inline=True)
|
||||
def real(val: ArrayLike, /) -> Array:
|
||||
check_arraylike("real", val)
|
||||
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
|
||||
|
||||
@_wraps(np.modf, module='numpy', skip_params=['out'])
|
||||
@implements(np.modf, module='numpy', skip_params=['out'])
|
||||
@jit
|
||||
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
|
||||
check_arraylike("modf", x)
|
||||
@ -666,7 +666,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
|
||||
return x - whole, whole
|
||||
|
||||
|
||||
@_wraps(np.isfinite, module='numpy')
|
||||
@implements(np.isfinite, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def isfinite(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("isfinite", x)
|
||||
@ -679,7 +679,7 @@ def isfinite(x: ArrayLike, /) -> Array:
|
||||
return lax.full_like(x, True, dtype=np.bool_)
|
||||
|
||||
|
||||
@_wraps(np.isinf, module='numpy')
|
||||
@implements(np.isinf, module='numpy')
|
||||
@jit
|
||||
def isinf(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("isinf", x)
|
||||
@ -707,24 +707,24 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
|
||||
return lax.full_like(x, False, dtype=np.bool_)
|
||||
|
||||
|
||||
isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
|
||||
isposinf: UnOp = implements(np.isposinf, skip_params=['out'])(
|
||||
lambda x, /, out=None: _isposneginf(np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
|
||||
isneginf: UnOp = implements(np.isneginf, skip_params=['out'])(
|
||||
lambda x, /, out=None: _isposneginf(-np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(np.isnan, module='numpy')
|
||||
@implements(np.isnan, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def isnan(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("isnan", x)
|
||||
return lax.ne(x, x)
|
||||
|
||||
|
||||
@_wraps(np.heaviside, module='numpy')
|
||||
@implements(np.heaviside, module='numpy')
|
||||
@jit
|
||||
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("heaviside", x1, x2)
|
||||
@ -734,7 +734,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
|
||||
|
||||
|
||||
@_wraps(np.hypot, module='numpy')
|
||||
@implements(np.hypot, module='numpy')
|
||||
@jit
|
||||
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("hypot", x1, x2)
|
||||
@ -745,7 +745,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
|
||||
|
||||
|
||||
@_wraps(np.reciprocal, module='numpy')
|
||||
@implements(np.reciprocal, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def reciprocal(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("reciprocal", x)
|
||||
@ -753,7 +753,7 @@ def reciprocal(x: ArrayLike, /) -> Array:
|
||||
return lax.integer_pow(x, -1)
|
||||
|
||||
|
||||
@_wraps(np.sinc, update_doc=False)
|
||||
@implements(np.sinc, update_doc=False)
|
||||
@jit
|
||||
def sinc(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("sinc", x)
|
||||
|
@ -112,13 +112,13 @@ def _parse_parameters(body: str) -> dict[str, str]:
|
||||
|
||||
|
||||
def _parse_extra_params(extra_params: str) -> dict[str, str]:
|
||||
"""Parse the extra parameters passed to _wraps()"""
|
||||
"""Parse the extra parameters passed to implements()"""
|
||||
parameters = _parameter_break.split(extra_params.strip('\n'))
|
||||
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
|
||||
|
||||
|
||||
def _wraps(
|
||||
fun: Callable[..., Any] | None,
|
||||
def implements(
|
||||
original_fun: Callable[..., Any] | None,
|
||||
update_doc: bool = True,
|
||||
lax_description: str = "",
|
||||
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
|
||||
@ -126,46 +126,46 @@ def _wraps(
|
||||
extra_params: str | None = None,
|
||||
module: str | None = None,
|
||||
) -> Callable[[_T], _T]:
|
||||
"""Specialized version of functools.wraps for wrapping numpy functions.
|
||||
"""Decorator for JAX functions which implement a specified NumPy function.
|
||||
|
||||
This produces a wrapped function with a modified docstring. In particular, if
|
||||
`update_doc` is True, parameters listed in the wrapped function that are not
|
||||
supported by the decorated function will be removed from the docstring. For
|
||||
this reason, it is important that parameter names match those in the original
|
||||
numpy function.
|
||||
This mainly contains logic to copy and modify the docstring of the original
|
||||
function. In particular, if `update_doc` is True, parameters listed in the
|
||||
original function that are not supported by the decorated function will
|
||||
be removed from the docstring. For this reason, it is important that parameter
|
||||
names match those in the original numpy function.
|
||||
|
||||
Args:
|
||||
fun: The function being wrapped
|
||||
original_fun: The original function being implemented
|
||||
update_doc: whether to transform the numpy docstring to remove references of
|
||||
parameters that are supported by the numpy version but not the JAX version.
|
||||
If False, include the numpy docstring verbatim.
|
||||
lax_description: a string description that will be added to the beginning of
|
||||
the docstring.
|
||||
sections: a list of sections to include in the docstring. The default is
|
||||
["Parameters", "returns", "References"]
|
||||
["Parameters", "Returns", "References"]
|
||||
skip_params: a list of strings containing names of parameters accepted by the
|
||||
function that should be skipped in the parameter list.
|
||||
extra_params: an optional string containing additional parameter descriptions.
|
||||
When ``update_doc=True``, these will be added to the list of parameter
|
||||
descriptions in the updated doc.
|
||||
module: an optional string specifying the module from which the wrapped function
|
||||
module: an optional string specifying the module from which the original function
|
||||
is imported. This is useful for objects such as ufuncs, where the module cannot
|
||||
be determined from the wrapped function itself.
|
||||
be determined from the original function itself.
|
||||
"""
|
||||
def wrap(op):
|
||||
op.__np_wrapped__ = fun
|
||||
# Allows this pattern: @wraps(getattr(np, 'new_function', None))
|
||||
if fun is None:
|
||||
def decorator(wrapped_fun):
|
||||
wrapped_fun.__np_wrapped__ = original_fun
|
||||
# Allows this pattern: @implements(getattr(np, 'new_function', None))
|
||||
if original_fun is None:
|
||||
if lax_description:
|
||||
op.__doc__ = lax_description
|
||||
return op
|
||||
docstr = getattr(fun, "__doc__", None)
|
||||
name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))
|
||||
wrapped_fun.__doc__ = lax_description
|
||||
return wrapped_fun
|
||||
docstr = getattr(original_fun, "__doc__", None)
|
||||
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
|
||||
try:
|
||||
mod = module or fun.__module__
|
||||
mod = module or original_fun.__module__
|
||||
except AttributeError:
|
||||
if config.enable_checks.value:
|
||||
raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
|
||||
raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().")
|
||||
else:
|
||||
name = f"{mod}.{name}"
|
||||
if docstr:
|
||||
@ -173,7 +173,7 @@ def _wraps(
|
||||
parsed = _parse_numpydoc(docstr)
|
||||
|
||||
if update_doc and 'Parameters' in parsed.sections:
|
||||
code = getattr(getattr(op, "__wrapped__", op), "__code__", None)
|
||||
code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None)
|
||||
# Remove unrecognized parameter descriptions.
|
||||
parameters = _parse_parameters(parsed.sections['Parameters'])
|
||||
if extra_params:
|
||||
@ -211,18 +211,18 @@ def _wraps(
|
||||
except:
|
||||
if config.enable_checks.value:
|
||||
raise
|
||||
docstr = fun.__doc__
|
||||
docstr = original_fun.__doc__
|
||||
|
||||
op.__doc__ = docstr
|
||||
wrapped_fun.__doc__ = docstr
|
||||
for attr in ['__name__', '__qualname__']:
|
||||
try:
|
||||
value = getattr(fun, attr)
|
||||
value = getattr(original_fun, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
setattr(op, attr, value)
|
||||
return op
|
||||
return wrap
|
||||
setattr(wrapped_fun, attr, value)
|
||||
return wrapped_fun
|
||||
return decorator
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
|
@ -19,7 +19,7 @@ import textwrap
|
||||
|
||||
from jax import vmap
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.numpy.util import implements, check_arraylike, promote_dtypes_inexact
|
||||
|
||||
|
||||
_no_chkfinite_doc = textwrap.dedent("""
|
||||
@ -28,7 +28,7 @@ because compiled JAX code cannot perform checks of array values at runtime
|
||||
""")
|
||||
|
||||
|
||||
@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
|
||||
@implements(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
|
||||
def vq(obs, code_book, check_finite=True):
|
||||
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
|
||||
if obs.ndim != code_book.ndim:
|
||||
|
@ -22,7 +22,7 @@ import scipy.fft as osp_fft
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.numpy.util import _wraps, promote_dtypes_complex
|
||||
from jax._src.numpy.util import implements, promote_dtypes_complex
|
||||
from jax._src.typing import Array
|
||||
|
||||
def _W4(N: int, k: Array) -> Array:
|
||||
@ -42,7 +42,7 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
|
||||
# Implementation based on
|
||||
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
|
||||
|
||||
@_wraps(osp_fft.dct)
|
||||
@implements(osp_fft.dct)
|
||||
def dct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
if type != 2:
|
||||
@ -81,7 +81,7 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(osp_fft.dctn)
|
||||
@implements(osp_fft.dctn)
|
||||
def dctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
@ -109,7 +109,7 @@ def dctn(x: Array, type: int = 2,
|
||||
return x
|
||||
|
||||
|
||||
@_wraps(osp_fft.dct)
|
||||
@implements(osp_fft.dct)
|
||||
def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
if type != 2:
|
||||
@ -139,7 +139,7 @@ def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
out = _dct_deinterleave(x.real, axis)
|
||||
return out
|
||||
|
||||
@_wraps(osp_fft.idctn)
|
||||
@implements(osp_fft.idctn)
|
||||
def idctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
|
@ -23,7 +23,7 @@ from jax._src.numpy import util
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import jax.numpy as jnp
|
||||
|
||||
@util._wraps(scipy.integrate.trapezoid)
|
||||
@util.implements(scipy.integrate.trapezoid)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
|
@ -29,7 +29,7 @@ from jax._src import dtypes
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import qdwh
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, _wraps, promote_dtypes, promote_dtypes_inexact,
|
||||
check_arraylike, implements, promote_dtypes, promote_dtypes_inexact,
|
||||
promote_dtypes_complex)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@ -46,14 +46,14 @@ def _cholesky(a: ArrayLike, lower: bool) -> Array:
|
||||
l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False)
|
||||
return l if lower else jnp.conj(l.mT)
|
||||
|
||||
@_wraps(scipy.linalg.cholesky,
|
||||
@implements(scipy.linalg.cholesky,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> Array:
|
||||
del overwrite_a, check_finite # Unused
|
||||
return _cholesky(a, lower)
|
||||
|
||||
@_wraps(scipy.linalg.cho_factor,
|
||||
@implements(scipy.linalg.cho_factor,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> tuple[Array, bool]:
|
||||
@ -70,7 +70,7 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:
|
||||
transpose_a=lower, conjugate_a=lower)
|
||||
return b
|
||||
|
||||
@_wraps(scipy.linalg.cho_solve, update_doc=False,
|
||||
@implements(scipy.linalg.cho_solve, update_doc=False,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
|
||||
def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike,
|
||||
overwrite_b: bool = False, check_finite: bool = True) -> Array:
|
||||
@ -112,7 +112,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
overwrite_a: bool = False, check_finite: bool = True,
|
||||
lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.svd,
|
||||
@implements(scipy.linalg.svd,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
|
||||
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
overwrite_a: bool = False, check_finite: bool = True,
|
||||
@ -120,7 +120,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
del overwrite_a, check_finite, lapack_driver # unused
|
||||
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
|
||||
@_wraps(scipy.linalg.det,
|
||||
@implements(scipy.linalg.det,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
|
||||
del overwrite_a, check_finite # unused
|
||||
@ -182,7 +182,7 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
|
||||
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
|
||||
type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.eigh,
|
||||
@implements(scipy.linalg.eigh,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc,
|
||||
skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite'))
|
||||
def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True,
|
||||
@ -198,21 +198,21 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]:
|
||||
a = a.astype(dtypes.to_complex_dtype(a.dtype))
|
||||
return lax_linalg.schur(a)
|
||||
|
||||
@_wraps(scipy.linalg.schur)
|
||||
@implements(scipy.linalg.schur)
|
||||
def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
|
||||
if output not in ('real', 'complex'):
|
||||
raise ValueError(
|
||||
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
|
||||
return _schur(a, output)
|
||||
|
||||
@_wraps(scipy.linalg.inv,
|
||||
@implements(scipy.linalg.inv,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
|
||||
del overwrite_a, check_finite # unused
|
||||
return jnp.linalg.inv(a)
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.lu_factor,
|
||||
@implements(scipy.linalg.lu_factor,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
|
||||
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]:
|
||||
@ -222,7 +222,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True
|
||||
return lu, pivots
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.lu_solve,
|
||||
@implements(scipy.linalg.lu_solve,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
|
||||
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
|
||||
def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
|
||||
@ -269,7 +269,7 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False,
|
||||
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.lu, update_doc=False,
|
||||
@implements(scipy.linalg.lu, update_doc=False,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
|
||||
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
|
||||
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
|
||||
@ -320,7 +320,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit
|
||||
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
|
||||
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.qr,
|
||||
@implements(scipy.linalg.qr,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork'))
|
||||
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
|
||||
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]:
|
||||
@ -352,7 +352,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
|
||||
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.solve,
|
||||
@implements(scipy.linalg.solve,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc,
|
||||
skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
|
||||
@ -391,7 +391,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str,
|
||||
else:
|
||||
return out
|
||||
|
||||
@_wraps(scipy.linalg.solve_triangular,
|
||||
@implements(scipy.linalg.solve_triangular,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite'))
|
||||
def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False,
|
||||
unit_diagonal: bool = False, overwrite_b: bool = False,
|
||||
@ -414,7 +414,7 @@ where norm() denotes the L1 norm, and
|
||||
- c=1.97 for float32 or complex64
|
||||
""")
|
||||
|
||||
@_wraps(scipy.linalg.expm, lax_description=_expm_description)
|
||||
@implements(scipy.linalg.expm, lax_description=_expm_description)
|
||||
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
|
||||
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
|
||||
A, = promote_dtypes_inexact(A)
|
||||
@ -572,7 +572,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
|
||||
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
|
||||
compute_expm: bool = True) -> Array | tuple[Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
|
||||
@implements(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
|
||||
@partial(jit, static_argnames=('method', 'compute_expm'))
|
||||
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
|
||||
compute_expm: bool = True) -> Array | tuple[Array, Array]:
|
||||
@ -597,7 +597,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None,
|
||||
return expm_frechet_AE
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.block_diag)
|
||||
@implements(scipy.linalg.block_diag)
|
||||
@jit
|
||||
def block_diag(*arrs: ArrayLike) -> Array:
|
||||
if len(arrs) == 0:
|
||||
@ -619,7 +619,7 @@ def block_diag(*arrs: ArrayLike) -> Array:
|
||||
return acc
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.eigh_tridiagonal)
|
||||
@implements(scipy.linalg.eigh_tridiagonal)
|
||||
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
|
||||
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
|
||||
select: str = 'a', select_range: tuple[float, float] | None = None,
|
||||
@ -901,7 +901,7 @@ def _sqrtm(A: ArrayLike) -> Array:
|
||||
return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST),
|
||||
jnp.conj(Z.T), precision=lax.Precision.HIGHEST)
|
||||
|
||||
@_wraps(scipy.linalg.sqrtm,
|
||||
@implements(scipy.linalg.sqrtm,
|
||||
lax_description="""
|
||||
This differs from ``scipy.linalg.sqrtm`` in that the return type of
|
||||
``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input,
|
||||
@ -918,7 +918,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:
|
||||
raise NotImplementedError("Blocked version is not implemented yet.")
|
||||
return _sqrtm(A)
|
||||
|
||||
@_wraps(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
|
||||
@implements(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
|
||||
@partial(jit, static_argnames=('check_finite',))
|
||||
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
|
||||
del check_finite # unused
|
||||
@ -987,7 +987,7 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals
|
||||
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> tuple[Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
|
||||
@implements(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
|
||||
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
|
||||
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> Array | tuple[Array, Array]:
|
||||
@ -1010,7 +1010,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
||||
else:
|
||||
return h
|
||||
|
||||
@_wraps(scipy.linalg.toeplitz)
|
||||
@implements(scipy.linalg.toeplitz)
|
||||
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||
if r is None:
|
||||
check_arraylike("toeplitz", c)
|
||||
|
@ -25,7 +25,7 @@ from jax._src import api
|
||||
from jax._src import util
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import implements
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import safe_zip as zip
|
||||
|
||||
@ -127,7 +127,7 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
|
||||
return result.astype(input_arr.dtype)
|
||||
|
||||
|
||||
@_wraps(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
|
||||
@implements(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
|
||||
Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and
|
||||
modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported.
|
||||
Note that interpolation near boundaries differs from the scipy function,
|
||||
|
@ -34,13 +34,13 @@ from jax._src import dtypes
|
||||
from jax._src.lax.lax import PrecisionLike
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, _wraps, promote_dtypes_inexact, promote_dtypes_complex)
|
||||
check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex)
|
||||
from jax._src.third_party.scipy import signal_helper
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
|
||||
|
||||
|
||||
@_wraps(osp_signal.fftconvolve)
|
||||
@implements(osp_signal.fftconvolve)
|
||||
def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full",
|
||||
axes: Sequence[int] | None = None) -> Array:
|
||||
check_arraylike('fftconvolve', in1, in2)
|
||||
@ -133,7 +133,7 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike)
|
||||
return result[0, 0]
|
||||
|
||||
|
||||
@_wraps(osp_signal.convolve)
|
||||
@implements(osp_signal.convolve)
|
||||
def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
precision: PrecisionLike = None) -> Array:
|
||||
if method == 'fft':
|
||||
@ -144,7 +144,7 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.")
|
||||
|
||||
|
||||
@_wraps(osp_signal.convolve2d)
|
||||
@implements(osp_signal.convolve2d)
|
||||
def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
|
||||
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
|
||||
if boundary != 'fill' or fillvalue != 0:
|
||||
@ -154,13 +154,13 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill
|
||||
return _convolve_nd(in1, in2, mode, precision=precision)
|
||||
|
||||
|
||||
@_wraps(osp_signal.correlate)
|
||||
@implements(osp_signal.correlate)
|
||||
def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
precision: PrecisionLike = None) -> Array:
|
||||
return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method)
|
||||
|
||||
|
||||
@_wraps(osp_signal.correlate2d)
|
||||
@implements(osp_signal.correlate2d)
|
||||
def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
|
||||
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
|
||||
if boundary != 'fill' or fillvalue != 0:
|
||||
@ -191,7 +191,7 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil
|
||||
return result
|
||||
|
||||
|
||||
@_wraps(osp_signal.detrend)
|
||||
@implements(osp_signal.detrend)
|
||||
def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0,
|
||||
overwrite_data: None = None) -> Array:
|
||||
if overwrite_data is not None:
|
||||
@ -499,7 +499,7 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
|
||||
return freqs, time, result
|
||||
|
||||
|
||||
@_wraps(osp_signal.stft)
|
||||
@implements(osp_signal.stft)
|
||||
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
|
||||
noverlap: int | None = None, nfft: int | None = None,
|
||||
detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros',
|
||||
@ -518,7 +518,7 @@ to follow the latter behavior. For using the former behavior, call this
|
||||
function as `csd(x, None)`."""
|
||||
|
||||
|
||||
@_wraps(osp_signal.csd, lax_description=_csd_description)
|
||||
@implements(osp_signal.csd, lax_description=_csd_description)
|
||||
def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, detrend: str = 'constant',
|
||||
@ -551,7 +551,7 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann'
|
||||
return freqs, Pxy
|
||||
|
||||
|
||||
@_wraps(osp_signal.welch)
|
||||
@implements(osp_signal.welch)
|
||||
def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, detrend: str = 'constant',
|
||||
@ -613,7 +613,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
|
||||
return x.reshape(tuple(batch_shape) + (-1,))
|
||||
|
||||
|
||||
@_wraps(osp_signal.istft)
|
||||
@implements(osp_signal.istft)
|
||||
def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, input_onesided: bool = True,
|
||||
|
@ -22,10 +22,10 @@ import scipy.spatial.transform
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import implements
|
||||
|
||||
|
||||
@_wraps(scipy.spatial.transform.Rotation)
|
||||
@implements(scipy.spatial.transform.Rotation)
|
||||
class Rotation(typing.NamedTuple):
|
||||
"""Rotation in 3 dimensions."""
|
||||
|
||||
@ -169,7 +169,7 @@ class Rotation(typing.NamedTuple):
|
||||
return self.quat.ndim == 1
|
||||
|
||||
|
||||
@_wraps(scipy.spatial.transform.Slerp)
|
||||
@implements(scipy.spatial.transform.Slerp)
|
||||
class Slerp(typing.NamedTuple):
|
||||
"""Spherical Linear Interpolation of Rotations."""
|
||||
|
||||
|
@ -32,19 +32,19 @@ from jax._src import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import implements
|
||||
from jax._src.ops import special as ops_special
|
||||
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_special.gammaln, module='scipy.special')
|
||||
@implements(osp_special.gammaln, module='scipy.special')
|
||||
def gammaln(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("gammaln", x)
|
||||
return lax.lgamma(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.gamma, module='scipy.special', lax_description="""\
|
||||
@implements(osp_special.gamma, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts real-valued inputs.""")
|
||||
def gamma(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("gamma", x)
|
||||
@ -53,14 +53,14 @@ def gamma(x: ArrayLike) -> Array:
|
||||
sign = jnp.where((x > 0) | (x == floor_x), 1.0, (-1.0) ** floor_x)
|
||||
return sign * lax.exp(lax.lgamma(x))
|
||||
|
||||
betaln = _wraps(
|
||||
betaln = implements(
|
||||
osp_special.betaln,
|
||||
module='scipy.special',
|
||||
update_doc=False
|
||||
)(_betaln_impl)
|
||||
|
||||
|
||||
@_wraps(osp_special.factorial, module='scipy.special')
|
||||
@implements(osp_special.factorial, module='scipy.special')
|
||||
def factorial(n: ArrayLike, exact: bool = False) -> Array:
|
||||
if exact:
|
||||
raise NotImplementedError("factorial with exact=True")
|
||||
@ -68,58 +68,58 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array:
|
||||
return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1)))
|
||||
|
||||
|
||||
@_wraps(osp_special.beta, module='scipy.special')
|
||||
@implements(osp_special.beta, module='scipy.special')
|
||||
def beta(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
x, y = promote_args_inexact("beta", x, y)
|
||||
return lax.exp(betaln(x, y))
|
||||
|
||||
|
||||
@_wraps(osp_special.betainc, module='scipy.special')
|
||||
@implements(osp_special.betainc, module='scipy.special')
|
||||
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
|
||||
a, b, x = promote_args_inexact("betainc", a, b, x)
|
||||
return lax.betainc(a, b, x)
|
||||
|
||||
|
||||
@_wraps(osp_special.digamma, module='scipy.special', lax_description="""\
|
||||
@implements(osp_special.digamma, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts real-valued inputs.""")
|
||||
def digamma(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("digamma", x)
|
||||
return lax.digamma(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.gammainc, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.gammainc, module='scipy.special', update_doc=False)
|
||||
def gammainc(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
a, x = promote_args_inexact("gammainc", a, x)
|
||||
return lax.igamma(a, x)
|
||||
|
||||
|
||||
@_wraps(osp_special.gammaincc, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.gammaincc, module='scipy.special', update_doc=False)
|
||||
def gammaincc(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
a, x = promote_args_inexact("gammaincc", a, x)
|
||||
return lax.igammac(a, x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"],
|
||||
@implements(osp_special.erf, module='scipy.special', skip_params=["out"],
|
||||
lax_description="Note that the JAX version does not support complex inputs.")
|
||||
def erf(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("erf", x)
|
||||
return lax.erf(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erfc, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.erfc, module='scipy.special', update_doc=False)
|
||||
def erfc(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("erfc", x)
|
||||
return lax.erfc(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erfinv, module='scipy.special')
|
||||
@implements(osp_special.erfinv, module='scipy.special')
|
||||
def erfinv(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("erfinv", x)
|
||||
return lax.erf_inv(x)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.logit, module='scipy.special', update_doc=False)
|
||||
def logit(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("logit", x)
|
||||
return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x)))
|
||||
@ -127,17 +127,17 @@ logit.defjvps(
|
||||
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))
|
||||
|
||||
|
||||
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.expit, module='scipy.special', update_doc=False)
|
||||
def expit(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("expit", x)
|
||||
return lax.logistic(x)
|
||||
|
||||
|
||||
logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
|
||||
logsumexp = implements(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.xlogy, module='scipy.special')
|
||||
@implements(osp_special.xlogy, module='scipy.special')
|
||||
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
# Note: xlogy(0, 0) should return 0 according to the function documentation.
|
||||
x, y = promote_args_inexact("xlogy", x, y)
|
||||
@ -153,7 +153,7 @@ xlogy.defjvp(_xlogy_jvp)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.xlog1py, module='scipy.special', update_doc=False)
|
||||
def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
# Note: xlog1py(0, -1) should return 0 according to the function documentation.
|
||||
x, y = promote_args_inexact("xlog1py", x, y)
|
||||
@ -179,14 +179,14 @@ def _xlogx_jvp(primals, tangents):
|
||||
_xlogx.defjvp(_xlogx_jvp)
|
||||
|
||||
|
||||
@_wraps(osp_special.entr, module='scipy.special')
|
||||
@implements(osp_special.entr, module='scipy.special')
|
||||
def entr(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("entr", x)
|
||||
return lax.select(lax.lt(x, _lax_const(x, 0)),
|
||||
lax.full_like(x, -np.inf),
|
||||
lax.neg(_xlogx(x)))
|
||||
|
||||
@_wraps(osp_special.multigammaln, update_doc=False)
|
||||
@implements(osp_special.multigammaln, update_doc=False)
|
||||
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
|
||||
d = core.concrete_or_error(int, d, "d argument of multigammaln")
|
||||
a, d_ = promote_args_inexact("multigammaln", a, d)
|
||||
@ -201,7 +201,7 @@ def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
|
||||
return res + constant
|
||||
|
||||
|
||||
@_wraps(osp_special.kl_div, module="scipy.special")
|
||||
@implements(osp_special.kl_div, module="scipy.special")
|
||||
def kl_div(
|
||||
p: ArrayLike,
|
||||
q: ArrayLike,
|
||||
@ -227,7 +227,7 @@ def kl_div(
|
||||
return result
|
||||
|
||||
|
||||
@_wraps(osp_special.rel_entr, module="scipy.special")
|
||||
@implements(osp_special.rel_entr, module="scipy.special")
|
||||
def rel_entr(
|
||||
p: ArrayLike,
|
||||
q: ArrayLike,
|
||||
@ -268,7 +268,7 @@ _BERNOULLI_COEFS = [
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.zeta, module='scipy.special')
|
||||
@implements(osp_special.zeta, module='scipy.special')
|
||||
def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array:
|
||||
if q is None:
|
||||
raise NotImplementedError(
|
||||
@ -311,7 +311,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
|
||||
zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False)
|
||||
@implements(osp_special.polygamma, module='scipy.special', update_doc=False)
|
||||
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
|
||||
assert jnp.issubdtype(lax.dtype(n), jnp.integer)
|
||||
n_arr, x_arr = promote_args_inexact("polygamma", n, x)
|
||||
@ -725,22 +725,22 @@ def _norm_logpdf(x):
|
||||
log_normalizer = _lax_const(x, _norm_logpdf_constant)
|
||||
return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
|
||||
|
||||
@_wraps(osp_special.i0e, module='scipy.special')
|
||||
@implements(osp_special.i0e, module='scipy.special')
|
||||
def i0e(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("i0e", x)
|
||||
return lax.bessel_i0e(x)
|
||||
|
||||
@_wraps(osp_special.i0, module='scipy.special')
|
||||
@implements(osp_special.i0, module='scipy.special')
|
||||
def i0(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("i0", x)
|
||||
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x))
|
||||
|
||||
@_wraps(osp_special.i1e, module='scipy.special')
|
||||
@implements(osp_special.i1e, module='scipy.special')
|
||||
def i1e(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("i1e", x)
|
||||
return lax.bessel_i1e(x)
|
||||
|
||||
@_wraps(osp_special.i1, module='scipy.special')
|
||||
@implements(osp_special.i1, module='scipy.special')
|
||||
def i1(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("i1", x)
|
||||
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
|
||||
@ -1459,7 +1459,7 @@ def _expi_neg(x: Array) -> Array:
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@jit
|
||||
@_wraps(osp_special.expi, module='scipy.special')
|
||||
@implements(osp_special.expi, module='scipy.special')
|
||||
def expi(x: ArrayLike) -> Array:
|
||||
x_arr, = promote_args_inexact("expi", x)
|
||||
return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos])
|
||||
@ -1577,7 +1577,7 @@ def _expn3(n: int, x: Array) -> Array:
|
||||
|
||||
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,))
|
||||
@jnp.vectorize
|
||||
@_wraps(osp_special.expn, module='scipy.special')
|
||||
@implements(osp_special.expn, module='scipy.special')
|
||||
@jit
|
||||
def expn(n: ArrayLike, x: ArrayLike) -> Array:
|
||||
n, x = promote_args_inexact("expn", n, x)
|
||||
@ -1615,7 +1615,7 @@ def expn_jvp(n, primals, tangents):
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_special.exp1, module="scipy.special")
|
||||
@implements(osp_special.exp1, module="scipy.special")
|
||||
def exp1(x: ArrayLike, module='scipy.special') -> Array:
|
||||
x, = promote_args_inexact("exp1", x)
|
||||
# Casting because custom_jvp generic does not work correctly with mypy.
|
||||
@ -1716,7 +1716,7 @@ def spence(x: Array) -> Array:
|
||||
return _spence(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.bernoulli, module='scipy.special')
|
||||
@implements(osp_special.bernoulli, module='scipy.special')
|
||||
def bernoulli(n: int) -> Array:
|
||||
# Generate Bernoulli numbers using the Chowla and Hartung algorithm.
|
||||
n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli")
|
||||
@ -1734,7 +1734,7 @@ def bernoulli(n: int) -> Array:
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.poch, module='scipy.special', lax_description="""\
|
||||
@implements(osp_special.poch, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts positive and real inputs.""")
|
||||
def poch(z: ArrayLike, m: ArrayLike) -> Array:
|
||||
# Factorial definition when m is close to an integer, otherwise gamma definition.
|
||||
@ -1883,7 +1883,7 @@ def _hyp1f1_x_derivative(a, b, x):
|
||||
@custom_derivatives.custom_jvp
|
||||
@jit
|
||||
@jnp.vectorize
|
||||
@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\
|
||||
@implements(osp_special.hyp1f1, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts positive and real inputs. Values of a, b and x
|
||||
leading to high values of 1F1 might be erroneous, considering enabling double
|
||||
precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""")
|
||||
|
@ -23,7 +23,7 @@ import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import vmap
|
||||
from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import check_arraylike, implements, promote_args_inexact
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
@ -31,7 +31,7 @@ import scipy
|
||||
|
||||
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
|
||||
|
||||
@_wraps(scipy.stats.mode, lax_description="""\
|
||||
@implements(scipy.stats.mode, lax_description="""\
|
||||
Currently the only supported nan_policy is 'propagate'
|
||||
""")
|
||||
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
@ -90,7 +90,7 @@ def invert_permutation(i: Array) -> Array:
|
||||
"""Helper function that inverts a permutation array."""
|
||||
return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype))
|
||||
|
||||
@_wraps(scipy.stats.rankdata, lax_description="""\
|
||||
@implements(scipy.stats.rankdata, lax_description="""\
|
||||
Currently the only supported nan_policy is 'propagate'
|
||||
""")
|
||||
@partial(jit, static_argnames=["method", "axis", "nan_policy"])
|
||||
@ -148,7 +148,7 @@ def rankdata(
|
||||
return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_))
|
||||
raise ValueError(f"unknown method '{method}'")
|
||||
|
||||
@_wraps(scipy.stats.sem, lax_description="""\
|
||||
@implements(scipy.stats.sem, lax_description="""\
|
||||
Currently the only supported nan_policies are 'propagate' and 'omit'
|
||||
""")
|
||||
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import xlogy, xlog1py
|
||||
|
||||
|
||||
@_wraps(osp_stats.bernoulli.logpmf, update_doc=False)
|
||||
@implements(osp_stats.bernoulli.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
@ -33,11 +33,11 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
|
||||
-jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.bernoulli.pmf, update_doc=False)
|
||||
@implements(osp_stats.bernoulli.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.exp(logpmf(k, p, loc))
|
||||
|
||||
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
|
||||
@implements(osp_stats.bernoulli.cdf, update_doc=False)
|
||||
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
|
||||
k, p = promote_args_inexact('bernoulli.cdf', k, p)
|
||||
zero, one = _lax_const(k, 0), _lax_const(k, 1)
|
||||
@ -50,7 +50,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array:
|
||||
vals = [jnp.nan, zero, one - p, one]
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
|
||||
@implements(osp_stats.bernoulli.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
|
||||
q, p = promote_args_inexact('bernoulli.ppf', q, p)
|
||||
zero, one = _lax_const(q, 0), _lax_const(q, 1)
|
||||
|
@ -17,12 +17,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logpdf, update_doc=False)
|
||||
@implements(osp_stats.beta.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
|
||||
@ -36,13 +36,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
lax.lt(x, loc)), -jnp.inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.pdf, update_doc=False)
|
||||
@implements(osp_stats.beta.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.cdf, update_doc=False)
|
||||
@implements(osp_stats.beta.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
|
||||
@ -57,13 +57,13 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
||||
@implements(osp_stats.beta.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.sf, update_doc=False)
|
||||
@implements(osp_stats.beta.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
|
||||
@ -78,7 +78,7 @@ def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logsf, update_doc=False)
|
||||
@implements(osp_stats.beta.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, a, b, loc, scale))
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.scipy.special import betaln
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.betabinom.logpmf, update_doc=False)
|
||||
@implements(osp_stats.betabinom.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.betabinom.logpmf."""
|
||||
@ -40,7 +40,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
return jnp.where(n_a_b_cond, jnp.nan, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.betabinom.pmf, update_doc=False)
|
||||
@implements(osp_stats.betabinom.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.betabinom.pmf."""
|
||||
|
@ -17,12 +17,12 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
|
||||
@implements(osp_stats.nbinom.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.binom.logpmf."""
|
||||
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
|
||||
@ -36,7 +36,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
|
||||
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf)
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
|
||||
@implements(osp_stats.nbinom.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.binom.pmf."""
|
||||
return lax.exp(logpmf(k, n, p, loc))
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax.numpy import arctan
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
@ -32,13 +32,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
@ -46,24 +46,24 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
@ -72,7 +72,7 @@ def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
|
||||
@implements(osp_stats.cauchy.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammainc, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||
@implements(osp_stats.chi2.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
@ -38,12 +38,12 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.chi2.pdf, update_doc=False)
|
||||
@implements(osp_stats.chi2.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.cdf, update_doc=False)
|
||||
@implements(osp_stats.chi2.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
@ -60,12 +60,12 @@ def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
|
||||
@implements(osp_stats.chi2.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
||||
@implements(osp_stats.chi2.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
@ -82,6 +82,6 @@ def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logsf, update_doc=False)
|
||||
@implements(osp_stats.chi2.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, df, loc, scale))
|
||||
|
@ -18,7 +18,7 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_dtypes_inexact, _wraps
|
||||
from jax._src.numpy.util import promote_dtypes_inexact, implements
|
||||
from jax.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@ -28,7 +28,7 @@ def _is_simplex(x: Array) -> Array:
|
||||
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
|
||||
|
||||
|
||||
@_wraps(osp_stats.dirichlet.logpdf, update_doc=False)
|
||||
@implements(osp_stats.dirichlet.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
|
||||
return _logpdf(*promote_dtypes_inexact(x, alpha))
|
||||
|
||||
@ -52,6 +52,6 @@ def _logpdf(x: Array, alpha: Array) -> Array:
|
||||
return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
|
||||
|
||||
|
||||
@_wraps(osp_stats.dirichlet.pdf, update_doc=False)
|
||||
@implements(osp_stats.dirichlet.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, alpha))
|
||||
|
@ -16,11 +16,11 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.expon.logpdf, update_doc=False)
|
||||
@implements(osp_stats.expon.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
|
||||
log_scale = lax.log(scale)
|
||||
@ -28,6 +28,6 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
log_probs = lax.neg(lax.add(linear_term, log_scale))
|
||||
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.expon.pdf, update_doc=False)
|
||||
@implements(osp_stats.expon.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
@ -17,12 +17,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
|
||||
@implements(osp_stats.gamma.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
@ -32,12 +32,12 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
||||
log_probs = lax.sub(log_linear_term, shape_terms)
|
||||
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.gamma.pdf, update_doc=False)
|
||||
@implements(osp_stats.gamma.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.cdf, update_doc=False)
|
||||
@implements(osp_stats.gamma.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
|
||||
return gammainc(
|
||||
@ -50,17 +50,17 @@ def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
|
||||
@implements(osp_stats.gamma.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.sf, update_doc=False)
|
||||
@implements(osp_stats.gamma.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logsf, update_doc=False)
|
||||
@implements(osp_stats.gamma.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, a, loc, scale))
|
||||
|
@ -14,20 +14,20 @@
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
|
||||
@implements(osp_stats.gennorm.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, p: ArrayLike) -> Array:
|
||||
x, p = promote_args_inexact("gennorm.logpdf", x, p)
|
||||
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
|
||||
|
||||
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
|
||||
@implements(osp_stats.gennorm.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, p: ArrayLike) -> Array:
|
||||
x, p = promote_args_inexact("gennorm.cdf", x, p)
|
||||
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
|
||||
|
||||
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
|
||||
@implements(osp_stats.gennorm.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, p: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, p))
|
||||
|
@ -17,12 +17,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax.scipy.special import xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.geom.logpmf, update_doc=False)
|
||||
@implements(osp_stats.geom.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
@ -32,6 +32,6 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.where(lax.le(x, zero), -jnp.inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.geom.pmf, update_doc=False)
|
||||
@implements(osp_stats.geom.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.exp(logpmf(k, p, loc))
|
||||
|
@ -21,12 +21,12 @@ import scipy.stats as osp_stats
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, lax, random, vmap
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements
|
||||
from jax._src.tree_util import register_pytree_node_class
|
||||
from jax.scipy import linalg, special
|
||||
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde, update_doc=False)
|
||||
@register_pytree_node_class
|
||||
@dataclass(frozen=True, init=False)
|
||||
class gaussian_kde:
|
||||
@ -113,7 +113,7 @@ class gaussian_kde:
|
||||
def n(self):
|
||||
return self.dataset.shape[1]
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.evaluate, update_doc=False)
|
||||
def evaluate(self, points):
|
||||
check_arraylike("evaluate", points)
|
||||
points = self._reshape_points(points)
|
||||
@ -121,11 +121,11 @@ class gaussian_kde:
|
||||
points.T, self.inv_cov)
|
||||
return result[:, 0]
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.__call__, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.__call__, update_doc=False)
|
||||
def __call__(self, points):
|
||||
return self.evaluate(points)
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
|
||||
def integrate_gaussian(self, mean, cov):
|
||||
mean = jnp.atleast_1d(jnp.squeeze(mean))
|
||||
cov = jnp.atleast_2d(cov)
|
||||
@ -141,7 +141,7 @@ class gaussian_kde:
|
||||
return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
|
||||
mean)
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
|
||||
def integrate_box_1d(self, low, high):
|
||||
if self.d != 1:
|
||||
raise ValueError("integrate_box_1d() only handles 1D pdfs")
|
||||
@ -153,7 +153,7 @@ class gaussian_kde:
|
||||
high = jnp.squeeze((high - self.dataset) / sigma)
|
||||
return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
|
||||
def integrate_kde(self, other):
|
||||
if other.d != self.d:
|
||||
raise ValueError("KDEs are not the same dimensionality")
|
||||
@ -189,11 +189,11 @@ class gaussian_kde:
|
||||
dtype=self.dataset.dtype).T
|
||||
return self.dataset[:, ind] + eps
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.pdf, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.pdf, update_doc=False)
|
||||
def pdf(self, x):
|
||||
return self.evaluate(x)
|
||||
|
||||
@_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False)
|
||||
@implements(osp_stats.gaussian_kde.logpdf, update_doc=False)
|
||||
def logpdf(self, x):
|
||||
check_arraylike("logpdf", x)
|
||||
x = self._reshape_points(x)
|
||||
|
@ -16,11 +16,11 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.laplace.logpdf, update_doc=False)
|
||||
@implements(osp_stats.laplace.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
|
||||
two = _lax_const(x, 2)
|
||||
@ -28,12 +28,12 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.laplace.pdf, update_doc=False)
|
||||
@implements(osp_stats.laplace.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.laplace.cdf, update_doc=False)
|
||||
@implements(osp_stats.laplace.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
|
||||
half = _lax_const(x, 0.5)
|
||||
|
@ -18,11 +18,11 @@ from jax.scipy.special import expit, logit
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
|
||||
@implements(osp_stats.logistic.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
@ -31,30 +31,30 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.pdf, update_doc=False)
|
||||
@implements(osp_stats.logistic.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.ppf, update_doc=False)
|
||||
@implements(osp_stats.logistic.ppf, update_doc=False)
|
||||
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
|
||||
return lax.add(lax.mul(logit(x), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.sf, update_doc=False)
|
||||
@implements(osp_stats.logistic.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
|
||||
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.isf, update_doc=False)
|
||||
@implements(osp_stats.logistic.isf, update_doc=False)
|
||||
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
|
||||
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.cdf, update_doc=False)
|
||||
@implements(osp_stats.logistic.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
|
||||
return expit(lax.div(lax.sub(x, loc), scale))
|
||||
|
@ -16,12 +16,12 @@
|
||||
import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact, promote_args_numeric
|
||||
from jax._src.numpy.util import implements, promote_args_inexact, promote_args_numeric
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.multinomial.logpmf, update_doc=False)
|
||||
@implements(osp_stats.multinomial.logpmf, update_doc=False)
|
||||
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
|
||||
"""JAX implementation of scipy.stats.multinomial.logpmf."""
|
||||
p, = promote_args_inexact("multinomial.logpmf", p)
|
||||
@ -34,7 +34,7 @@ def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
|
||||
return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -jnp.inf)
|
||||
|
||||
|
||||
@_wraps(osp_stats.multinomial.pmf, update_doc=False)
|
||||
@implements(osp_stats.multinomial.pmf, update_doc=False)
|
||||
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
|
||||
"""JAX implementation of scipy.stats.multinomial.pmf."""
|
||||
return lax.exp(logpmf(x, n, p))
|
||||
|
@ -19,11 +19,11 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_dtypes_inexact
|
||||
from jax._src.numpy.util import implements, promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description="""
|
||||
@implements(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description="""
|
||||
In the JAX version, the `allow_singular` argument is not implemented.
|
||||
""")
|
||||
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike:
|
||||
@ -50,6 +50,6 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None =
|
||||
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
|
||||
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
|
||||
|
||||
@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False)
|
||||
@implements(osp_stats.multivariate_normal.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, mean, cov))
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
|
||||
@implements(osp_stats.nbinom.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.nbinom.logpmf."""
|
||||
k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
|
||||
@ -37,7 +37,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
|
||||
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
|
||||
@implements(osp_stats.nbinom.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.nbinom.pmf."""
|
||||
return lax.exp(logpmf(k, n, p, loc))
|
||||
|
@ -20,12 +20,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy import special
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logpdf, update_doc=False)
|
||||
@implements(osp_stats.norm.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
|
||||
scale_sqrd = lax.square(scale)
|
||||
@ -34,41 +34,41 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.pdf, update_doc=False)
|
||||
@implements(osp_stats.norm.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.cdf, update_doc=False)
|
||||
@implements(osp_stats.norm.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
|
||||
return special.ndtr(lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logcdf, update_doc=False)
|
||||
@implements(osp_stats.norm.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
|
||||
# Cast required because custom_jvp return type is broken.
|
||||
return cast(Array, special.log_ndtr(lax.div(lax.sub(x, loc), scale)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.ppf, update_doc=False)
|
||||
@implements(osp_stats.norm.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logsf, update_doc=False)
|
||||
@implements(osp_stats.norm.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.sf, update_doc=False)
|
||||
@implements(osp_stats.norm.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.isf, update_doc=False)
|
||||
@implements(osp_stats.norm.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
|
||||
|
@ -18,11 +18,11 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.pareto.logpdf, update_doc=False)
|
||||
@implements(osp_stats.pareto.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
@ -31,6 +31,6 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
||||
log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))))
|
||||
return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.pareto.pdf, update_doc=False)
|
||||
@implements(osp_stats.pareto.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, b, loc, scale))
|
||||
|
@ -18,12 +18,12 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import xlogy, gammaln, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.poisson.logpmf, update_doc=False)
|
||||
@implements(osp_stats.poisson.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
@ -31,11 +31,11 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
|
||||
return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.poisson.pmf, update_doc=False)
|
||||
@implements(osp_stats.poisson.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.exp(logpmf(k, mu, loc))
|
||||
|
||||
@_wraps(osp_stats.poisson.cdf, update_doc=False)
|
||||
@implements(osp_stats.poisson.cdf, update_doc=False)
|
||||
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
|
@ -18,11 +18,11 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.t.logpdf, update_doc=False)
|
||||
@implements(osp_stats.t.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
|
||||
two = _lax_const(x, 2)
|
||||
@ -37,6 +37,6 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.t.pdf, update_doc=False)
|
||||
@implements(osp_stats.t.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
@ -17,7 +17,7 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.scipy.stats import norm
|
||||
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
|
||||
|
||||
@ -69,7 +69,7 @@ def _log_gauss_mass(a, b):
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.logpdf, update_doc=False)
|
||||
def logpdf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
|
||||
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
|
||||
@ -80,23 +80,23 @@ def logpdf(x, a, b, loc=0, scale=1):
|
||||
return val
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.pdf, update_doc=False)
|
||||
def pdf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.logsf, update_doc=False)
|
||||
def logsf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
||||
return logcdf(-x, -b, -a, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.sf, update_doc=False)
|
||||
def sf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logsf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.logcdf, update_doc=False)
|
||||
def logcdf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
|
||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
||||
@ -113,6 +113,6 @@ def logcdf(x, a, b, loc=0, scale=1):
|
||||
return logcdf
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
|
||||
@implements(osp_stats.truncnorm.cdf, update_doc=False)
|
||||
def cdf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logcdf(x, a, b, loc, scale))
|
||||
|
@ -19,10 +19,10 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.numpy import where, inf, logical_or
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
|
||||
|
||||
@_wraps(osp_stats.uniform.logpdf, update_doc=False)
|
||||
@implements(osp_stats.uniform.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
|
||||
log_probs = lax.neg(lax.log(scale))
|
||||
@ -30,11 +30,11 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
lax.lt(x, loc)),
|
||||
-inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.uniform.pdf, update_doc=False)
|
||||
@implements(osp_stats.uniform.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
@_wraps(osp_stats.uniform.cdf, update_doc=False)
|
||||
@implements(osp_stats.uniform.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
|
||||
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
|
||||
@ -43,7 +43,7 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
@_wraps(osp_stats.uniform.ppf, update_doc=False)
|
||||
@implements(osp_stats.uniform.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
|
||||
return where(
|
||||
|
@ -17,15 +17,15 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@_wraps(osp_stats.vonmises.logpdf, update_doc=False)
|
||||
@implements(osp_stats.vonmises.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
|
||||
zero = _lax_const(kappa, 0)
|
||||
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan)
|
||||
|
||||
@_wraps(osp_stats.vonmises.pdf, update_doc=False)
|
||||
@implements(osp_stats.vonmises.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, kappa))
|
||||
|
@ -17,11 +17,11 @@ import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.numpy.util import implements, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
|
||||
@implements(osp_stats.wrapcauchy.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
|
||||
return jnp.where(
|
||||
@ -34,6 +34,6 @@ def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
|
||||
@implements(osp_stats.wrapcauchy.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, c))
|
||||
|
10
jax/_src/third_party/numpy/linalg.py
vendored
10
jax/_src/third_party/numpy/linalg.py
vendored
@ -2,7 +2,7 @@ import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax.numpy.linalg as la
|
||||
from jax._src.numpy.util import check_arraylike, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, implements
|
||||
|
||||
|
||||
def _isEmpty2d(arr):
|
||||
@ -39,7 +39,7 @@ def _assert2d(*arrays):
|
||||
'Array must be two-dimensional')
|
||||
|
||||
|
||||
@_wraps(np.linalg.cond)
|
||||
@implements(np.linalg.cond)
|
||||
def cond(x, p=None):
|
||||
check_arraylike('jnp.linalg.cond', x)
|
||||
_assertNoEmpty2d(x)
|
||||
@ -62,7 +62,7 @@ def cond(x, p=None):
|
||||
return r
|
||||
|
||||
|
||||
@_wraps(np.linalg.tensorinv)
|
||||
@implements(np.linalg.tensorinv)
|
||||
def tensorinv(a, ind=2):
|
||||
check_arraylike('jnp.linalg.tensorinv', a)
|
||||
a = jnp.asarray(a)
|
||||
@ -79,7 +79,7 @@ def tensorinv(a, ind=2):
|
||||
return ia.reshape(*invshape)
|
||||
|
||||
|
||||
@_wraps(np.linalg.tensorsolve)
|
||||
@implements(np.linalg.tensorsolve)
|
||||
def tensorsolve(a, b, axes=None):
|
||||
check_arraylike('jnp.linalg.tensorsolve', a, b)
|
||||
a = jnp.asarray(a)
|
||||
@ -108,7 +108,7 @@ def tensorsolve(a, b, axes=None):
|
||||
return res
|
||||
|
||||
|
||||
@_wraps(np.linalg.multi_dot)
|
||||
@implements(np.linalg.multi_dot)
|
||||
def multi_dot(arrays, *, precision=None):
|
||||
check_arraylike('jnp.linalg.multi_dot', *arrays)
|
||||
n = len(arrays)
|
||||
|
6
jax/_src/third_party/scipy/interpolate.py
vendored
6
jax/_src/third_party/scipy/interpolate.py
vendored
@ -4,7 +4,7 @@ import scipy.interpolate as osp_interpolate
|
||||
from jax.numpy import (asarray, broadcast_arrays, can_cast,
|
||||
empty, nan, searchsorted, where, zeros)
|
||||
from jax._src.tree_util import register_pytree_node
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements
|
||||
|
||||
|
||||
def _ndim_coords_from_arrays(points, ndim=None):
|
||||
@ -31,7 +31,7 @@ def _ndim_coords_from_arrays(points, ndim=None):
|
||||
return points
|
||||
|
||||
|
||||
@_wraps(
|
||||
@implements(
|
||||
osp_interpolate.RegularGridInterpolator,
|
||||
lax_description="""
|
||||
In the JAX version, `bounds_error` defaults to and must always be `False` since no
|
||||
@ -76,7 +76,7 @@ class RegularGridInterpolator:
|
||||
self.grid = tuple(asarray(p) for p in points)
|
||||
self.values = values
|
||||
|
||||
@_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
|
||||
@implements(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
|
||||
def __call__(self, xi, method=None):
|
||||
method = self.method if method is None else method
|
||||
if method not in ("linear", "nearest"):
|
||||
|
4
jax/_src/third_party/scipy/linalg.py
vendored
4
jax/_src/third_party/scipy/linalg.py
vendored
@ -7,7 +7,7 @@ import scipy.linalg
|
||||
from jax import jit, lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.linalg import norm
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import implements
|
||||
from jax._src.scipy.linalg import rsf2csf, schur
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
|
||||
@ -51,7 +51,7 @@ Additionally, unlike the SciPy implementation, when ``disp=True`` no warning
|
||||
will be printed if the error in the array output is estimated to be large.
|
||||
"""
|
||||
|
||||
@_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
|
||||
@implements(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
|
||||
def funm(A: ArrayLike, func: Callable[[Array], Array],
|
||||
disp: bool = True) -> Array | tuple[Array, Array]:
|
||||
A_arr = jnp.asarray(A)
|
||||
|
@ -52,7 +52,7 @@ from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements
|
||||
from jax._src.util import safe_zip, NumpyComplexWarning
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -5861,7 +5861,7 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
if jit:
|
||||
wrapped = jax.jit(wrapped)
|
||||
|
||||
wrapped = _wraps(orig, skip_params=['out'])(wrapped)
|
||||
wrapped = implements(orig, skip_params=['out'])(wrapped)
|
||||
doc = wrapped.__doc__
|
||||
|
||||
self.assertStartsWith(doc, "Example Docstring")
|
||||
|
Loading…
x
Reference in New Issue
Block a user