Deprecate scipy.special.lpmn & lpmn_values

This commit is contained in:
Jake VanderPlas 2025-01-06 09:31:15 -08:00
parent 6c87bf389f
commit 245a13a329
3 changed files with 29 additions and 6 deletions

View File

@ -30,6 +30,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag

View File

@ -46,8 +46,8 @@ from jax._src.scipy.special import (
log_softmax as log_softmax,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
lpmn as _deprecated_lpmn,
lpmn_values as _deprecated_lpmn_values,
multigammaln as multigammaln,
ndtr as ndtr,
ndtri as ndtri,
@ -65,3 +65,25 @@ from jax._src.scipy.special import (
from jax._src.third_party.scipy.special import (
fresnel as fresnel,
)
_deprecations = {
# Added Jan 3 2024
"lpmn": (
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
_deprecated_lpmn,
),
"lpmn_values": (
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
_deprecated_lpmn_values,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
lpmn = _deprecated_lpmn
lpmn_values = _deprecated_lpmn_values
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing

View File

@ -332,8 +332,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
shape=[(5,), (10,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.lpmn` is deprecated")
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testLpmn(self, l_max, shape, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
@ -356,8 +355,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
shape=[(2,), (3,), (4,), (64,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.lpmn` is deprecated")
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]