mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate scipy.special.lpmn & lpmn_values
This commit is contained in:
parent
6c87bf389f
commit
245a13a329
@ -30,6 +30,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
|
||||
are now deprecated, having been replaced by symbols of the same name
|
||||
in {mod}`jax.core`.
|
||||
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
|
||||
are deprecated, following their deprecation in SciPy v1.15.0. There are
|
||||
no plans to replace these deprecated functions with new APIs.
|
||||
|
||||
* Deletions
|
||||
* `jax_enable_memories` flag has been deleted and the behavior of that flag
|
||||
|
@ -46,8 +46,8 @@ from jax._src.scipy.special import (
|
||||
log_softmax as log_softmax,
|
||||
logit as logit,
|
||||
logsumexp as logsumexp,
|
||||
lpmn as lpmn,
|
||||
lpmn_values as lpmn_values,
|
||||
lpmn as _deprecated_lpmn,
|
||||
lpmn_values as _deprecated_lpmn_values,
|
||||
multigammaln as multigammaln,
|
||||
ndtr as ndtr,
|
||||
ndtri as ndtri,
|
||||
@ -65,3 +65,25 @@ from jax._src.scipy.special import (
|
||||
from jax._src.third_party.scipy.special import (
|
||||
fresnel as fresnel,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added Jan 3 2024
|
||||
"lpmn": (
|
||||
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
|
||||
_deprecated_lpmn,
|
||||
),
|
||||
"lpmn_values": (
|
||||
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
|
||||
_deprecated_lpmn_values,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
lpmn = _deprecated_lpmn
|
||||
lpmn_values = _deprecated_lpmn_values
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
@ -332,8 +332,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
shape=[(5,), (10,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.lpmn` is deprecated")
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
|
||||
def testLpmn(self, l_max, shape, dtype):
|
||||
if jtu.is_device_tpu(6, "e"):
|
||||
self.skipTest("TODO(b/364258243): fails on TPU v6e")
|
||||
@ -356,8 +355,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
shape=[(2,), (3,), (4,), (64,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.lpmn` is deprecated")
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
|
||||
def testNormalizedLpmnValues(self, l_max, shape, dtype):
|
||||
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user