mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement for the already deprecated trapz function.
This commit is contained in:
parent
011ced4431
commit
9e01afe7af
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.26
|
## jax 0.4.26
|
||||||
|
|
||||||
|
* New Functionality
|
||||||
|
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
|
||||||
|
NumPy 2.0.
|
||||||
|
|
||||||
* Deprecations & Removals
|
* Deprecations & Removals
|
||||||
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
||||||
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
||||||
|
@ -391,6 +391,7 @@ namespace; they are listed below.
|
|||||||
tensordot
|
tensordot
|
||||||
tile
|
tile
|
||||||
trace
|
trace
|
||||||
|
trapezoid
|
||||||
transpose
|
transpose
|
||||||
tri
|
tri
|
||||||
tril
|
tril
|
||||||
|
@ -2877,6 +2877,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
|||||||
return take(a, gather_indices, axis=axis)
|
return take(a, gather_indices, axis=axis)
|
||||||
|
|
||||||
|
|
||||||
|
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
|
||||||
|
@partial(jit, static_argnames=('axis',))
|
||||||
|
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||||
|
axis: int = -1) -> Array:
|
||||||
|
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||||
|
dx_array: Array
|
||||||
|
if x is None:
|
||||||
|
util.check_arraylike('trapezoid', y)
|
||||||
|
y_arr, = util.promote_dtypes_inexact(y)
|
||||||
|
dx_array = asarray(dx)
|
||||||
|
else:
|
||||||
|
util.check_arraylike('trapezoid', y, x)
|
||||||
|
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
||||||
|
if x_arr.ndim == 1:
|
||||||
|
dx_array = diff(x_arr)
|
||||||
|
else:
|
||||||
|
dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1)
|
||||||
|
y_arr = moveaxis(y_arr, axis, -1)
|
||||||
|
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
||||||
|
|
||||||
|
|
||||||
@util.implements(np.tri)
|
@util.implements(np.tri)
|
||||||
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
|
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
|
||||||
dtypes.check_user_dtype_supported(dtype, "tri")
|
dtypes.check_user_dtype_supported(dtype, "tri")
|
||||||
|
@ -27,18 +27,4 @@ import jax.numpy as jnp
|
|||||||
@partial(jit, static_argnames=('axis',))
|
@partial(jit, static_argnames=('axis',))
|
||||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||||
axis: int = -1) -> Array:
|
axis: int = -1) -> Array:
|
||||||
# TODO(phawkins): remove this annotation after fixing jnp types.
|
return jnp.trapezoid(y, x, dx, axis)
|
||||||
dx_array: Array
|
|
||||||
if x is None:
|
|
||||||
util.check_arraylike('trapezoid', y)
|
|
||||||
y_arr, = util.promote_dtypes_inexact(y)
|
|
||||||
dx_array = jnp.asarray(dx)
|
|
||||||
else:
|
|
||||||
util.check_arraylike('trapezoid', y, x)
|
|
||||||
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
|
||||||
if x_arr.ndim == 1:
|
|
||||||
dx_array = jnp.diff(x_arr)
|
|
||||||
else:
|
|
||||||
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
|
|
||||||
y_arr = jnp.moveaxis(y_arr, axis, -1)
|
|
||||||
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
|
||||||
|
@ -233,6 +233,7 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
tensordot as tensordot,
|
tensordot as tensordot,
|
||||||
tile as tile,
|
tile as tile,
|
||||||
trace as trace,
|
trace as trace,
|
||||||
|
trapezoid as trapezoid,
|
||||||
transpose as transpose,
|
transpose as transpose,
|
||||||
tri as tri,
|
tri as tri,
|
||||||
tril as tril,
|
tril as tril,
|
||||||
@ -447,7 +448,15 @@ from jax._src.numpy.array_methods import register_jax_array_methods
|
|||||||
register_jax_array_methods()
|
register_jax_array_methods()
|
||||||
del register_jax_array_methods
|
del register_jax_array_methods
|
||||||
|
|
||||||
try:
|
|
||||||
from numpy import issubsctype as _deprecated_issubsctype
|
_deprecations = {
|
||||||
except ImportError:
|
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
|
||||||
_deprecated_issubsctype = None
|
"trapz": (
|
||||||
|
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
|
||||||
|
None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
|
del _deprecation_getattr
|
||||||
|
@ -788,6 +788,8 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
|
|||||||
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
|
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
|
||||||
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
|
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
|
||||||
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
|
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
|
||||||
|
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ...,
|
||||||
|
axis: int = ...) -> Array: ...
|
||||||
def tri(
|
def tri(
|
||||||
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
|
N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ...
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
|
@ -5571,6 +5571,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
numpy_result = np.isdtype(dtype, kind)
|
numpy_result = np.isdtype(dtype, kind)
|
||||||
self.assertEqual(jax_result, numpy_result)
|
self.assertEqual(jax_result, numpy_result)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
|
||||||
|
for yshape, xshape, dx, axis in [
|
||||||
|
((10,), None, 1.0, -1),
|
||||||
|
((3, 10), None, 2.0, -1),
|
||||||
|
((3, 10), None, 3.0, -0),
|
||||||
|
((10, 3), (10,), 1.0, -2),
|
||||||
|
((3, 10), (10,), 1.0, -1),
|
||||||
|
((3, 10), (3, 10), 1.0, -1),
|
||||||
|
((2, 3, 10), (3, 10), 1.0, -2),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=float_dtypes + int_dtypes,
|
||||||
|
)
|
||||||
|
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
|
||||||
|
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||||
|
def test_trapezoid(self, yshape, xshape, dtype, dx, axis):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
|
||||||
|
if jtu.numpy_version() >= (2, 0, 0):
|
||||||
|
np_fun = partial(np.trapezoid, dx=dx, axis=axis)
|
||||||
|
else:
|
||||||
|
np_fun = partial(np.trapz, dx=dx, axis=axis)
|
||||||
|
jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis)
|
||||||
|
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
|
||||||
|
jax.dtypes.bfloat16: 4e-2})
|
||||||
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
|
||||||
|
check_dtypes=False)
|
||||||
|
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
|
||||||
|
check_dtypes=False)
|
||||||
|
|
||||||
|
|
||||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||||
# as needed for e.g. particular compound ops of interest.
|
# as needed for e.g. particular compound ops of interest.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user