Deprecate jax.numpy.trapz.

Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
This commit is contained in:
Peter Hawkins 2023-08-24 14:01:40 -06:00
parent a454081390
commit 975dae34a4
8 changed files with 117 additions and 1 deletions

View File

@ -14,6 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`, {meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and {meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`). {meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
* Added {func}`jax.scipy.integrate.trapezoid`.
* When not running under IPython: when an exception is raised, JAX now filters out the * When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See that previously appeared.) This should produce much friendlier-looking tracebacks. See
@ -44,6 +45,7 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`. * `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead. * `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead. * `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, * `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead. following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. * `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.

View File

@ -14,6 +14,16 @@ jax.scipy.fft
idct idct
idctn idctn
jax.scipy.integrate
-------------------
.. automodule:: jax.scipy.integrate
.. autosummary::
:toctree: _autosummary
trapezoid
jax.scipy.linalg jax.scipy.linalg
---------------- ----------------

View File

@ -0,0 +1,44 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
import scipy.integrate
from jax import jit
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike
import jax.numpy as jnp
@util._wraps(scipy.integrate.trapezoid)
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None:
util.check_arraylike('trapz', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = jnp.asarray(dx)
else:
util.check_arraylike('trapz', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = jnp.diff(x_arr)
else:
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
y_arr = jnp.moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)

View File

@ -226,7 +226,7 @@ from jax._src.numpy.lax_numpy import (
tensordot as tensordot, tensordot as tensordot,
tile as tile, tile as tile,
trace as trace, trace as trace,
trapz as trapz, trapz as _deprecated_trapz,
transpose as transpose, transpose as transpose,
tri as tri, tri as tri,
tril as tril, tril as tril,
@ -474,6 +474,11 @@ _deprecations = {
"jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.", "jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.",
_deprecated_in1d, _deprecated_in1d,
), ),
# Added Aug 24, 2023
"trapz": (
"jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.",
_deprecated_trapz,
),
} }
import typing import typing
@ -488,6 +493,7 @@ if typing.TYPE_CHECKING:
PZERO = 0.0 PZERO = 0.0
issubsctype = _numpy.core.numerictypes.issubsctype issubsctype = _numpy.core.numerictypes.issubsctype
in1d = _deprecated_in1d in1d = _deprecated_in1d
trapz = _deprecated_trapz
else: else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations) __getattr__ = _deprecation_getattr(__name__, _deprecations)
@ -496,3 +502,4 @@ del typing
del _numpy del _numpy
del _deprecated_in1d del _deprecated_in1d
del _deprecated_trapz

View File

@ -27,6 +27,7 @@ if TYPE_CHECKING:
from jax.scipy import stats as stats from jax.scipy import stats as stats
from jax.scipy import fft as fft from jax.scipy import fft as fft
from jax.scipy import cluster as cluster from jax.scipy import cluster as cluster
from jax.scipy import integrate as integrate
else: else:
import jax._src.lazy_loader as _lazy import jax._src.lazy_loader as _lazy
__getattr__, __dir__, __all__ = _lazy.attach(__name__, [ __getattr__, __dir__, __all__ = _lazy.attach(__name__, [
@ -39,6 +40,7 @@ else:
"stats", "stats",
"fft", "fft",
"cluster", "cluster",
"integrate",
]) ])
del _lazy del _lazy

20
jax/scipy/integrate.py Normal file
View File

@ -0,0 +1,20 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.integrate import (
trapezoid as trapezoid
)

View File

@ -73,6 +73,7 @@ filterwarnings = [
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning", "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
"ignore:np.find_common_type is deprecated.*:DeprecationWarning", "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning", "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
] ]
doctest_optionflags = [ doctest_optionflags = [
"NUMBER", "NUMBER",

View File

@ -20,10 +20,12 @@ import unittest
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
import scipy.integrate
import scipy.special as osp_special import scipy.special as osp_special
import scipy.cluster as osp_cluster import scipy.cluster as osp_cluster
import jax import jax
import jax.dtypes
from jax import numpy as jnp from jax import numpy as jnp
from jax import lax from jax import lax
from jax import scipy as jsp from jax import scipy as jsp
@ -542,5 +544,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertArraysEqual(actual, nan_array, check_dtypes=False) self.assertArraysEqual(actual, nan_array, check_dtypes=False)
@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=float_dtypes + int_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
np_fun = partial(scipy.integrate.trapezoid, dx=dx, axis=axis)
jnp_fun = partial(jax.scipy.integrate.trapezoid, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
jax.dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())