Add jax.numpy.trapezoid

This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
This commit is contained in:
Jake VanderPlas 2024-04-01 13:05:20 -07:00
parent 011ced4431
commit 9e01afe7af
7 changed files with 73 additions and 19 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.26 ## jax 0.4.26
* New Functionality
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
NumPy 2.0.
* Deprecations & Removals * Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.

View File

@ -391,6 +391,7 @@ namespace; they are listed below.
tensordot tensordot
tile tile
trace trace
trapezoid
transpose transpose
tri tri
tril tril

View File

@ -2877,6 +2877,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
return take(a, gather_indices, axis=axis) return take(a, gather_indices, axis=axis)
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None:
util.check_arraylike('trapezoid', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = asarray(dx)
else:
util.check_arraylike('trapezoid', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = diff(x_arr)
else:
dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1)
y_arr = moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
@util.implements(np.tri) @util.implements(np.tri)
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "tri") dtypes.check_user_dtype_supported(dtype, "tri")

View File

@ -27,18 +27,4 @@ import jax.numpy as jnp
@partial(jit, static_argnames=('axis',)) @partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array: axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types. return jnp.trapezoid(y, x, dx, axis)
dx_array: Array
if x is None:
util.check_arraylike('trapezoid', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = jnp.asarray(dx)
else:
util.check_arraylike('trapezoid', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = jnp.diff(x_arr)
else:
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
y_arr = jnp.moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)

View File

@ -233,6 +233,7 @@ from jax._src.numpy.lax_numpy import (
tensordot as tensordot, tensordot as tensordot,
tile as tile, tile as tile,
trace as trace, trace as trace,
trapezoid as trapezoid,
transpose as transpose, transpose as transpose,
tri as tri, tri as tri,
tril as tril, tril as tril,
@ -447,7 +448,15 @@ from jax._src.numpy.array_methods import register_jax_array_methods
register_jax_array_methods() register_jax_array_methods()
del register_jax_array_methods del register_jax_array_methods
try:
from numpy import issubsctype as _deprecated_issubsctype _deprecations = {
except ImportError: # Deprecated 18 Sept 2023 and removed 06 Feb 2024
_deprecated_issubsctype = None "trapz": (
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
None
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -788,6 +788,8 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ..., def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ...,
axis: int = ...) -> Array: ...
def tri( def tri(
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
) -> Array: ... ) -> Array: ...

View File

@ -5571,6 +5571,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
numpy_result = np.isdtype(dtype, kind) numpy_result = np.isdtype(dtype, kind)
self.assertEqual(jax_result, numpy_result) self.assertEqual(jax_result, numpy_result)
@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=float_dtypes + int_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_trapezoid(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.trapezoid, dx=dx, axis=axis)
else:
np_fun = partial(np.trapz, dx=dx, axis=axis)
jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
jax.dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
# Most grad tests are at the lax level (see lax_test.py), but we add some here # Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest. # as needed for e.g. particular compound ops of interest.