mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement for the already deprecated trapz function.
This commit is contained in:
parent
011ced4431
commit
9e01afe7af
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.26
|
||||
|
||||
* New Functionality
|
||||
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
|
||||
NumPy 2.0.
|
||||
|
||||
* Deprecations & Removals
|
||||
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
||||
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
||||
|
@ -391,6 +391,7 @@ namespace; they are listed below.
|
||||
tensordot
|
||||
tile
|
||||
trace
|
||||
trapezoid
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
|
@ -2877,6 +2877,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
return take(a, gather_indices, axis=axis)
|
||||
|
||||
|
||||
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||
dx_array: Array
|
||||
if x is None:
|
||||
util.check_arraylike('trapezoid', y)
|
||||
y_arr, = util.promote_dtypes_inexact(y)
|
||||
dx_array = asarray(dx)
|
||||
else:
|
||||
util.check_arraylike('trapezoid', y, x)
|
||||
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
||||
if x_arr.ndim == 1:
|
||||
dx_array = diff(x_arr)
|
||||
else:
|
||||
dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1)
|
||||
y_arr = moveaxis(y_arr, axis, -1)
|
||||
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
||||
|
||||
|
||||
@util.implements(np.tri)
|
||||
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
|
||||
dtypes.check_user_dtype_supported(dtype, "tri")
|
||||
|
@ -27,18 +27,4 @@ import jax.numpy as jnp
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||
dx_array: Array
|
||||
if x is None:
|
||||
util.check_arraylike('trapezoid', y)
|
||||
y_arr, = util.promote_dtypes_inexact(y)
|
||||
dx_array = jnp.asarray(dx)
|
||||
else:
|
||||
util.check_arraylike('trapezoid', y, x)
|
||||
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
||||
if x_arr.ndim == 1:
|
||||
dx_array = jnp.diff(x_arr)
|
||||
else:
|
||||
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
|
||||
y_arr = jnp.moveaxis(y_arr, axis, -1)
|
||||
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
||||
return jnp.trapezoid(y, x, dx, axis)
|
||||
|
@ -233,6 +233,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
tensordot as tensordot,
|
||||
tile as tile,
|
||||
trace as trace,
|
||||
trapezoid as trapezoid,
|
||||
transpose as transpose,
|
||||
tri as tri,
|
||||
tril as tril,
|
||||
@ -447,7 +448,15 @@ from jax._src.numpy.array_methods import register_jax_array_methods
|
||||
register_jax_array_methods()
|
||||
del register_jax_array_methods
|
||||
|
||||
try:
|
||||
from numpy import issubsctype as _deprecated_issubsctype
|
||||
except ImportError:
|
||||
_deprecated_issubsctype = None
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
|
||||
"trapz": (
|
||||
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
|
||||
None
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -788,6 +788,8 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
|
||||
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
|
||||
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
|
||||
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ...,
|
||||
axis: int = ...) -> Array: ...
|
||||
def tri(
|
||||
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
|
||||
) -> Array: ...
|
||||
|
@ -5571,6 +5571,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
numpy_result = np.isdtype(dtype, kind)
|
||||
self.assertEqual(jax_result, numpy_result)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
|
||||
for yshape, xshape, dx, axis in [
|
||||
((10,), None, 1.0, -1),
|
||||
((3, 10), None, 2.0, -1),
|
||||
((3, 10), None, 3.0, -0),
|
||||
((10, 3), (10,), 1.0, -2),
|
||||
((3, 10), (10,), 1.0, -1),
|
||||
((3, 10), (3, 10), 1.0, -1),
|
||||
((2, 3, 10), (3, 10), 1.0, -2),
|
||||
]
|
||||
],
|
||||
dtype=float_dtypes + int_dtypes,
|
||||
)
|
||||
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def test_trapezoid(self, yshape, xshape, dtype, dx, axis):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
|
||||
if jtu.numpy_version() >= (2, 0, 0):
|
||||
np_fun = partial(np.trapezoid, dx=dx, axis=axis)
|
||||
else:
|
||||
np_fun = partial(np.trapz, dx=dx, axis=axis)
|
||||
jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis)
|
||||
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
|
||||
jax.dtypes.bfloat16: 4e-2})
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
|
||||
check_dtypes=False)
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
Loading…
x
Reference in New Issue
Block a user