diff --git a/CHANGELOG.md b/CHANGELOG.md index d48e8db47..78ac98c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ Remember to align the itemized text with the first line of an item within a list `axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`, and `XLAOp`. * from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`, - and `in1d`. + `trapz`, and `in1d`. * from {mod}`jax.scipy.linalg`: `tril` and `triu`. ## jaxlib 0.4.24 diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d52619e43..2a3eb5958 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -378,7 +378,6 @@ namespace; they are listed below. tile trace transpose - trapz tri tril tril_indices diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3b7e8bc16..889d7d4ba 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -325,23 +325,6 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) -@util._wraps(np.trapz) -@partial(jit, static_argnames=('axis',)) -def trapz(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: - if x is None: - util.check_arraylike('trapz', y) - y_arr, = util.promote_dtypes_inexact(y) - else: - util.check_arraylike('trapz', y, x) - y_arr, x_arr = util.promote_dtypes_inexact(y, x) - if x_arr.ndim == 1: - dx = diff(x_arr) - else: - dx = moveaxis(diff(x_arr, axis=axis), axis, -1) - y_arr = moveaxis(y_arr, axis, -1) - return 0.5 * (dx * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) - - @util._wraps(np.trunc, module='numpy') @jit def trunc(x: ArrayLike) -> Array: diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index a1b9d074a..e17706af7 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -226,7 +226,6 @@ from jax._src.numpy.lax_numpy import ( tensordot as tensordot, tile as tile, trace as trace, - trapz as _deprecated_trapz, transpose as transpose, tri as tri, tril as tril, @@ -429,23 +428,3 @@ try: from numpy import issubsctype as _deprecated_issubsctype except ImportError: _deprecated_issubsctype = None - -# Deprecations - -_deprecations = { - # Added Aug 24, 2023 - "trapz": ( - "jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.", - _deprecated_trapz, - ), -} - -import typing -if typing.TYPE_CHECKING: - trapz = _deprecated_trapz -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _deprecated_trapz diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b1ff44046..50793254c 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -756,8 +756,6 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ... def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ..., dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... -def trapz(y: ArrayLike, x: Optional[ArrayLike] = ..., dx: ArrayLike = ..., - axis: int = ...) -> Array: ... def tri( N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... ) -> Array: ... diff --git a/pyproject.toml b/pyproject.toml index f1966bb9d..78a0ba38c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ filterwarnings = [ "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning", "ignore:np.find_common_type is deprecated.*:DeprecationWarning", "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning", - "ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning", # TODO(jakevdp): remove when array_api_tests stabilize # start array_api_tests-related warnings "ignore:The numpy.array_api submodule is still experimental.*:UserWarning", diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3767ec563..62be14828 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1991,36 +1991,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - @jtu.sample_product( - [dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis) - for yshape, xshape, dx, axis in [ - ((10,), None, 1.0, -1), - ((3, 10), None, 2.0, -1), - ((3, 10), None, 3.0, -0), - ((10, 3), (10,), 1.0, -2), - ((3, 10), (10,), 1.0, -1), - ((3, 10), (3, 10), 1.0, -1), - ((2, 3, 10), (3, 10), 1.0, -2), - ] - ], - dtype=default_dtypes, - ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testTrapz(self, yshape, xshape, dtype, dx, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] - # TODO(jakevdp): numpy.trapz is removed in numpy 2.0 - np_fun = jtu.ignore_warning(category=DeprecationWarning)( - partial(np.trapz, dx=dx, axis=axis)) - jnp_fun = partial(jnp.trapz, dx=dx, axis=axis) - tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12, - dtypes.bfloat16: 4e-2}) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol, - check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, - check_dtypes=False) - @jtu.sample_product( dtype=default_dtypes, n=[0, 4],