Add jax.numpy.trapezoid

This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
This commit is contained in:
Jake VanderPlas 2024-04-01 13:05:20 -07:00
parent 011ced4431
commit 9e01afe7af
7 changed files with 73 additions and 19 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.26
* New Functionality
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
NumPy 2.0.
* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.

View File

@ -391,6 +391,7 @@ namespace; they are listed below.
tensordot
tile
trace
trapezoid
transpose
tri
tril

View File

@ -2877,6 +2877,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
return take(a, gather_indices, axis=axis)
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None:
util.check_arraylike('trapezoid', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = asarray(dx)
else:
util.check_arraylike('trapezoid', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = diff(x_arr)
else:
dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1)
y_arr = moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
@util.implements(np.tri)
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "tri")

View File

@ -27,18 +27,4 @@ import jax.numpy as jnp
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None:
util.check_arraylike('trapezoid', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = jnp.asarray(dx)
else:
util.check_arraylike('trapezoid', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = jnp.diff(x_arr)
else:
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
y_arr = jnp.moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
return jnp.trapezoid(y, x, dx, axis)

View File

@ -233,6 +233,7 @@ from jax._src.numpy.lax_numpy import (
tensordot as tensordot,
tile as tile,
trace as trace,
trapezoid as trapezoid,
transpose as transpose,
tri as tri,
tril as tril,
@ -447,7 +448,15 @@ from jax._src.numpy.array_methods import register_jax_array_methods
register_jax_array_methods()
del register_jax_array_methods
try:
from numpy import issubsctype as _deprecated_issubsctype
except ImportError:
_deprecated_issubsctype = None
_deprecations = {
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
"trapz": (
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
None
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -788,6 +788,8 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ...,
axis: int = ...) -> Array: ...
def tri(
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
) -> Array: ...

View File

@ -5571,6 +5571,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
numpy_result = np.isdtype(dtype, kind)
self.assertEqual(jax_result, numpy_result)
@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=float_dtypes + int_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_trapezoid(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.trapezoid, dx=dx, axis=axis)
else:
np_fun = partial(np.trapz, dx=dx, axis=axis)
jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
jax.dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.