From 245a13a329b9e258e363be8584a28cb2999037f7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 6 Jan 2025 09:31:15 -0800 Subject: [PATCH] Deprecate scipy.special.lpmn & lpmn_values --- CHANGELOG.md | 3 +++ jax/scipy/special.py | 26 ++++++++++++++++++++++++-- tests/lax_scipy_test.py | 6 ++---- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb9404268..3ffdafbb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` are now deprecated, having been replaced by symbols of the same name in {mod}`jax.core`. + * {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values` + are deprecated, following their deprecation in SciPy v1.15.0. There are + no plans to replace these deprecated functions with new APIs. * Deletions * `jax_enable_memories` flag has been deleted and the behavior of that flag diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 431617d36..c83ead11a 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -46,8 +46,8 @@ from jax._src.scipy.special import ( log_softmax as log_softmax, logit as logit, logsumexp as logsumexp, - lpmn as lpmn, - lpmn_values as lpmn_values, + lpmn as _deprecated_lpmn, + lpmn_values as _deprecated_lpmn_values, multigammaln as multigammaln, ndtr as ndtr, ndtri as ndtri, @@ -65,3 +65,25 @@ from jax._src.scipy.special import ( from jax._src.third_party.scipy.special import ( fresnel as fresnel, ) + +_deprecations = { + # Added Jan 3 2024 + "lpmn": ( + "jax.scipy.special.lpmn is deprecated; no replacement is planned.", + _deprecated_lpmn, + ), + "lpmn_values": ( + "jax.scipy.special.lpmn_values is deprecated; no replacement is planned.", + _deprecated_lpmn_values, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + lpmn = _deprecated_lpmn + lpmn_values = _deprecated_lpmn_values +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index c83ea4020..c30e8da65 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -332,8 +332,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): shape=[(5,), (10,)], dtype=float_dtypes, ) - @jtu.ignore_warning(category=DeprecationWarning, - message="`scipy.special.lpmn` is deprecated") + @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testLpmn(self, l_max, shape, dtype): if jtu.is_device_tpu(6, "e"): self.skipTest("TODO(b/364258243): fails on TPU v6e") @@ -356,8 +355,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): shape=[(2,), (3,), (4,), (64,)], dtype=float_dtypes, ) - @jtu.ignore_warning(category=DeprecationWarning, - message="`scipy.special.lpmn` is deprecated") + @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testNormalizedLpmnValues(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)]