mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead. Fixes https://github.com/google/jax/issues/17244
This commit is contained in:
parent
a454081390
commit
975dae34a4
@ -14,6 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
|
||||
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
|
||||
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
|
||||
* Added {func}`jax.scipy.integrate.trapezoid`.
|
||||
* When not running under IPython: when an exception is raised, JAX now filters out the
|
||||
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
|
||||
that previously appeared.) This should produce much friendlier-looking tracebacks. See
|
||||
@ -44,6 +45,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
|
||||
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
|
||||
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
|
||||
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
|
||||
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
|
||||
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
|
||||
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
|
||||
|
@ -14,6 +14,16 @@ jax.scipy.fft
|
||||
idct
|
||||
idctn
|
||||
|
||||
jax.scipy.integrate
|
||||
-------------------
|
||||
|
||||
.. automodule:: jax.scipy.integrate
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
trapezoid
|
||||
|
||||
jax.scipy.linalg
|
||||
----------------
|
||||
|
||||
|
44
jax/_src/scipy/integrate.py
Normal file
44
jax/_src/scipy/integrate.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
import scipy.integrate
|
||||
|
||||
from jax import jit
|
||||
from jax._src.numpy import util
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import jax.numpy as jnp
|
||||
|
||||
@util._wraps(scipy.integrate.trapezoid)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||
dx_array: Array
|
||||
if x is None:
|
||||
util.check_arraylike('trapz', y)
|
||||
y_arr, = util.promote_dtypes_inexact(y)
|
||||
dx_array = jnp.asarray(dx)
|
||||
else:
|
||||
util.check_arraylike('trapz', y, x)
|
||||
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
||||
if x_arr.ndim == 1:
|
||||
dx_array = jnp.diff(x_arr)
|
||||
else:
|
||||
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
|
||||
y_arr = jnp.moveaxis(y_arr, axis, -1)
|
||||
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
@ -226,7 +226,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
tensordot as tensordot,
|
||||
tile as tile,
|
||||
trace as trace,
|
||||
trapz as trapz,
|
||||
trapz as _deprecated_trapz,
|
||||
transpose as transpose,
|
||||
tri as tri,
|
||||
tril as tril,
|
||||
@ -474,6 +474,11 @@ _deprecations = {
|
||||
"jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.",
|
||||
_deprecated_in1d,
|
||||
),
|
||||
# Added Aug 24, 2023
|
||||
"trapz": (
|
||||
"jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.",
|
||||
_deprecated_trapz,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
@ -488,6 +493,7 @@ if typing.TYPE_CHECKING:
|
||||
PZERO = 0.0
|
||||
issubsctype = _numpy.core.numerictypes.issubsctype
|
||||
in1d = _deprecated_in1d
|
||||
trapz = _deprecated_trapz
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
@ -496,3 +502,4 @@ del typing
|
||||
del _numpy
|
||||
|
||||
del _deprecated_in1d
|
||||
del _deprecated_trapz
|
||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from jax.scipy import stats as stats
|
||||
from jax.scipy import fft as fft
|
||||
from jax.scipy import cluster as cluster
|
||||
from jax.scipy import integrate as integrate
|
||||
else:
|
||||
import jax._src.lazy_loader as _lazy
|
||||
__getattr__, __dir__, __all__ = _lazy.attach(__name__, [
|
||||
@ -39,6 +40,7 @@ else:
|
||||
"stats",
|
||||
"fft",
|
||||
"cluster",
|
||||
"integrate",
|
||||
])
|
||||
del _lazy
|
||||
|
||||
|
20
jax/scipy/integrate.py
Normal file
20
jax/scipy/integrate.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.integrate import (
|
||||
trapezoid as trapezoid
|
||||
)
|
@ -73,6 +73,7 @@ filterwarnings = [
|
||||
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
|
||||
"ignore:np.find_common_type is deprecated.*:DeprecationWarning",
|
||||
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
|
||||
"ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
|
@ -20,10 +20,12 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import numpy as np
|
||||
import scipy.integrate
|
||||
import scipy.special as osp_special
|
||||
import scipy.cluster as osp_cluster
|
||||
|
||||
import jax
|
||||
import jax.dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
from jax import scipy as jsp
|
||||
@ -542,5 +544,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(actual, nan_array, check_dtypes=False)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
|
||||
for yshape, xshape, dx, axis in [
|
||||
((10,), None, 1.0, -1),
|
||||
((3, 10), None, 2.0, -1),
|
||||
((3, 10), None, 3.0, -0),
|
||||
((10, 3), (10,), 1.0, -2),
|
||||
((3, 10), (10,), 1.0, -1),
|
||||
((3, 10), (3, 10), 1.0, -1),
|
||||
((2, 3, 10), (3, 10), 1.0, -2),
|
||||
]
|
||||
],
|
||||
dtype=float_dtypes + int_dtypes,
|
||||
)
|
||||
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
|
||||
np_fun = partial(scipy.integrate.trapezoid, dx=dx, axis=axis)
|
||||
jnp_fun = partial(jax.scipy.integrate.trapezoid, dx=dx, axis=axis)
|
||||
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
|
||||
jax.dtypes.bfloat16: 4e-2})
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
|
||||
check_dtypes=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user