diff --git a/CHANGELOG.md b/CHANGELOG.md index 405c20923..6d393fe31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.26 +* New Functionality + * Added {func}`jax.numpy.trapezoid`, following the addition of this function in + NumPy 2.0. + * Deprecations & Removals * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index cdc557477..41c87603b 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -391,6 +391,7 @@ namespace; they are listed below. tensordot tile trace + trapezoid transpose tri tril diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1fdda450e..36175fa9a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2877,6 +2877,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None))) +@partial(jit, static_argnames=('axis',)) +def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, + axis: int = -1) -> Array: + # TODO(phawkins): remove this annotation after fixing jnp types. + dx_array: Array + if x is None: + util.check_arraylike('trapezoid', y) + y_arr, = util.promote_dtypes_inexact(y) + dx_array = asarray(dx) + else: + util.check_arraylike('trapezoid', y, x) + y_arr, x_arr = util.promote_dtypes_inexact(y, x) + if x_arr.ndim == 1: + dx_array = diff(x_arr) + else: + dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1) + y_arr = moveaxis(y_arr, axis, -1) + return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) + + @util.implements(np.tri) def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "tri") diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py index 97cfe0ff1..e60d8a06b 100644 --- a/jax/_src/scipy/integrate.py +++ b/jax/_src/scipy/integrate.py @@ -27,18 +27,4 @@ import jax.numpy as jnp @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: - # TODO(phawkins): remove this annotation after fixing jnp types. - dx_array: Array - if x is None: - util.check_arraylike('trapezoid', y) - y_arr, = util.promote_dtypes_inexact(y) - dx_array = jnp.asarray(dx) - else: - util.check_arraylike('trapezoid', y, x) - y_arr, x_arr = util.promote_dtypes_inexact(y, x) - if x_arr.ndim == 1: - dx_array = jnp.diff(x_arr) - else: - dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1) - y_arr = jnp.moveaxis(y_arr, axis, -1) - return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) + return jnp.trapezoid(y, x, dx, axis) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 72b2ed7f9..273c5a2aa 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -233,6 +233,7 @@ from jax._src.numpy.lax_numpy import ( tensordot as tensordot, tile as tile, trace as trace, + trapezoid as trapezoid, transpose as transpose, tri as tri, tril as tril, @@ -447,7 +448,15 @@ from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods() del register_jax_array_methods -try: - from numpy import issubsctype as _deprecated_issubsctype -except ImportError: - _deprecated_issubsctype = None + +_deprecations = { + # Deprecated 18 Sept 2023 and removed 06 Feb 2024 + "trapz": ( + "jnp.trapz is deprecated; use jnp.trapezoid instead.", + None + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index a618f4570..9ed5f39b3 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -788,6 +788,8 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ... def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ..., dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... +def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., + axis: int = ...) -> Array: ... def tri( N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... ) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ed3a16eda..6aa8fb9d8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5571,6 +5571,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): numpy_result = np.isdtype(dtype, kind) self.assertEqual(jax_result, numpy_result) + @jtu.sample_product( + [dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis) + for yshape, xshape, dx, axis in [ + ((10,), None, 1.0, -1), + ((3, 10), None, 2.0, -1), + ((3, 10), None, 3.0, -0), + ((10, 3), (10,), 1.0, -2), + ((3, 10), (10,), 1.0, -1), + ((3, 10), (3, 10), 1.0, -1), + ((2, 3, 10), (3, 10), 1.0, -2), + ] + ], + dtype=float_dtypes + int_dtypes, + ) + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def test_trapezoid(self, yshape, xshape, dtype, dx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.trapezoid, dx=dx, axis=axis) + else: + np_fun = partial(np.trapz, dx=dx, axis=axis) + jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis) + tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12, + jax.dtypes.bfloat16: 4e-2}) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol, + check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, + check_dtypes=False) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest.