Remove deprecated function jax.numpy.trapz

This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
This commit is contained in:
Jake VanderPlas 2023-12-19 09:57:01 -08:00 committed by jax authors
parent c172be1379
commit cab63114b4
7 changed files with 1 additions and 73 deletions

View File

@ -32,7 +32,7 @@ Remember to align the itemized text with the first line of an item within a list
`axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`,
and `XLAOp`.
* from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
and `in1d`.
`trapz`, and `in1d`.
* from {mod}`jax.scipy.linalg`: `tril` and `triu`.
## jaxlib 0.4.24

View File

@ -378,7 +378,6 @@ namespace; they are listed below.
tile
trace
transpose
trapz
tri
tril
tril_indices

View File

@ -325,23 +325,6 @@ def result_type(*args: Any) -> DType:
return dtypes.result_type(*args)
@util._wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array:
if x is None:
util.check_arraylike('trapz', y)
y_arr, = util.promote_dtypes_inexact(y)
else:
util.check_arraylike('trapz', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx = diff(x_arr)
else:
dx = moveaxis(diff(x_arr, axis=axis), axis, -1)
y_arr = moveaxis(y_arr, axis, -1)
return 0.5 * (dx * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
@util._wraps(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:

View File

@ -226,7 +226,6 @@ from jax._src.numpy.lax_numpy import (
tensordot as tensordot,
tile as tile,
trace as trace,
trapz as _deprecated_trapz,
transpose as transpose,
tri as tri,
tril as tril,
@ -429,23 +428,3 @@ try:
from numpy import issubsctype as _deprecated_issubsctype
except ImportError:
_deprecated_issubsctype = None
# Deprecations
_deprecations = {
# Added Aug 24, 2023
"trapz": (
"jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.",
_deprecated_trapz,
),
}
import typing
if typing.TYPE_CHECKING:
trapz = _deprecated_trapz
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _deprecated_trapz

View File

@ -756,8 +756,6 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
def trapz(y: ArrayLike, x: Optional[ArrayLike] = ..., dx: ArrayLike = ...,
axis: int = ...) -> Array: ...
def tri(
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
) -> Array: ...

View File

@ -76,7 +76,6 @@ filterwarnings = [
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
"ignore:np.find_common_type is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
# TODO(jakevdp): remove when array_api_tests stabilize
# start array_api_tests-related warnings
"ignore:The numpy.array_api submodule is still experimental.*:UserWarning",

View File

@ -1991,36 +1991,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=default_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testTrapz(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
# TODO(jakevdp): numpy.trapz is removed in numpy 2.0
np_fun = jtu.ignore_warning(category=DeprecationWarning)(
partial(np.trapz, dx=dx, axis=axis))
jnp_fun = partial(jnp.trapz, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
@jtu.sample_product(
dtype=default_dtypes,
n=[0, 4],