mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead. Fixes https://github.com/google/jax/issues/17244
This commit is contained in:
parent
a454081390
commit
975dae34a4
@ -14,6 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
|
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
|
||||||
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
|
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
|
||||||
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
|
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
|
||||||
|
* Added {func}`jax.scipy.integrate.trapezoid`.
|
||||||
* When not running under IPython: when an exception is raised, JAX now filters out the
|
* When not running under IPython: when an exception is raised, JAX now filters out the
|
||||||
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
|
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
|
||||||
that previously appeared.) This should produce much friendlier-looking tracebacks. See
|
that previously appeared.) This should produce much friendlier-looking tracebacks. See
|
||||||
@ -44,6 +45,7 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
|
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
|
||||||
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
|
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
|
||||||
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
|
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
|
||||||
|
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
|
||||||
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
|
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
|
||||||
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
|
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
|
||||||
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
|
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
|
||||||
|
@ -14,6 +14,16 @@ jax.scipy.fft
|
|||||||
idct
|
idct
|
||||||
idctn
|
idctn
|
||||||
|
|
||||||
|
jax.scipy.integrate
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
.. automodule:: jax.scipy.integrate
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
trapezoid
|
||||||
|
|
||||||
jax.scipy.linalg
|
jax.scipy.linalg
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
|
44
jax/_src/scipy/integrate.py
Normal file
44
jax/_src/scipy/integrate.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2023 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import scipy.integrate
|
||||||
|
|
||||||
|
from jax import jit
|
||||||
|
from jax._src.numpy import util
|
||||||
|
from jax._src.typing import Array, ArrayLike
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
@util._wraps(scipy.integrate.trapezoid)
|
||||||
|
@partial(jit, static_argnames=('axis',))
|
||||||
|
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||||
|
axis: int = -1) -> Array:
|
||||||
|
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||||
|
dx_array: Array
|
||||||
|
if x is None:
|
||||||
|
util.check_arraylike('trapz', y)
|
||||||
|
y_arr, = util.promote_dtypes_inexact(y)
|
||||||
|
dx_array = jnp.asarray(dx)
|
||||||
|
else:
|
||||||
|
util.check_arraylike('trapz', y, x)
|
||||||
|
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
|
||||||
|
if x_arr.ndim == 1:
|
||||||
|
dx_array = jnp.diff(x_arr)
|
||||||
|
else:
|
||||||
|
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
|
||||||
|
y_arr = jnp.moveaxis(y_arr, axis, -1)
|
||||||
|
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
|
@ -226,7 +226,7 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
tensordot as tensordot,
|
tensordot as tensordot,
|
||||||
tile as tile,
|
tile as tile,
|
||||||
trace as trace,
|
trace as trace,
|
||||||
trapz as trapz,
|
trapz as _deprecated_trapz,
|
||||||
transpose as transpose,
|
transpose as transpose,
|
||||||
tri as tri,
|
tri as tri,
|
||||||
tril as tril,
|
tril as tril,
|
||||||
@ -474,6 +474,11 @@ _deprecations = {
|
|||||||
"jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.",
|
"jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.",
|
||||||
_deprecated_in1d,
|
_deprecated_in1d,
|
||||||
),
|
),
|
||||||
|
# Added Aug 24, 2023
|
||||||
|
"trapz": (
|
||||||
|
"jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.",
|
||||||
|
_deprecated_trapz,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
@ -488,6 +493,7 @@ if typing.TYPE_CHECKING:
|
|||||||
PZERO = 0.0
|
PZERO = 0.0
|
||||||
issubsctype = _numpy.core.numerictypes.issubsctype
|
issubsctype = _numpy.core.numerictypes.issubsctype
|
||||||
in1d = _deprecated_in1d
|
in1d = _deprecated_in1d
|
||||||
|
trapz = _deprecated_trapz
|
||||||
else:
|
else:
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
@ -496,3 +502,4 @@ del typing
|
|||||||
del _numpy
|
del _numpy
|
||||||
|
|
||||||
del _deprecated_in1d
|
del _deprecated_in1d
|
||||||
|
del _deprecated_trapz
|
||||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
|||||||
from jax.scipy import stats as stats
|
from jax.scipy import stats as stats
|
||||||
from jax.scipy import fft as fft
|
from jax.scipy import fft as fft
|
||||||
from jax.scipy import cluster as cluster
|
from jax.scipy import cluster as cluster
|
||||||
|
from jax.scipy import integrate as integrate
|
||||||
else:
|
else:
|
||||||
import jax._src.lazy_loader as _lazy
|
import jax._src.lazy_loader as _lazy
|
||||||
__getattr__, __dir__, __all__ = _lazy.attach(__name__, [
|
__getattr__, __dir__, __all__ = _lazy.attach(__name__, [
|
||||||
@ -39,6 +40,7 @@ else:
|
|||||||
"stats",
|
"stats",
|
||||||
"fft",
|
"fft",
|
||||||
"cluster",
|
"cluster",
|
||||||
|
"integrate",
|
||||||
])
|
])
|
||||||
del _lazy
|
del _lazy
|
||||||
|
|
||||||
|
20
jax/scipy/integrate.py
Normal file
20
jax/scipy/integrate.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Copyright 2023 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Note: import <name> as <name> is required for names to be exported.
|
||||||
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
|
from jax._src.scipy.integrate import (
|
||||||
|
trapezoid as trapezoid
|
||||||
|
)
|
@ -73,6 +73,7 @@ filterwarnings = [
|
|||||||
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
|
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
|
||||||
"ignore:np.find_common_type is deprecated.*:DeprecationWarning",
|
"ignore:np.find_common_type is deprecated.*:DeprecationWarning",
|
||||||
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
|
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
|
||||||
|
"ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
|
||||||
]
|
]
|
||||||
doctest_optionflags = [
|
doctest_optionflags = [
|
||||||
"NUMBER",
|
"NUMBER",
|
||||||
|
@ -20,10 +20,12 @@ import unittest
|
|||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.integrate
|
||||||
import scipy.special as osp_special
|
import scipy.special as osp_special
|
||||||
import scipy.cluster as osp_cluster
|
import scipy.cluster as osp_cluster
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
import jax.dtypes
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import scipy as jsp
|
from jax import scipy as jsp
|
||||||
@ -542,5 +544,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
self.assertArraysEqual(actual, nan_array, check_dtypes=False)
|
self.assertArraysEqual(actual, nan_array, check_dtypes=False)
|
||||||
|
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
|
||||||
|
for yshape, xshape, dx, axis in [
|
||||||
|
((10,), None, 1.0, -1),
|
||||||
|
((3, 10), None, 2.0, -1),
|
||||||
|
((3, 10), None, 3.0, -0),
|
||||||
|
((10, 3), (10,), 1.0, -2),
|
||||||
|
((3, 10), (10,), 1.0, -1),
|
||||||
|
((3, 10), (3, 10), 1.0, -1),
|
||||||
|
((2, 3, 10), (3, 10), 1.0, -2),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=float_dtypes + int_dtypes,
|
||||||
|
)
|
||||||
|
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
|
||||||
|
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||||
|
def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
|
||||||
|
np_fun = partial(scipy.integrate.trapezoid, dx=dx, axis=axis)
|
||||||
|
jnp_fun = partial(jax.scipy.integrate.trapezoid, dx=dx, axis=axis)
|
||||||
|
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
|
||||||
|
jax.dtypes.bfloat16: 4e-2})
|
||||||
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
|
||||||
|
check_dtypes=False)
|
||||||
|
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
|
||||||
|
check_dtypes=False)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user