finalize deprecation of jax.clear_backends

This commit is contained in:
Jake VanderPlas 2024-11-12 12:26:34 -08:00
parent a1eb5ceade
commit f401c97967
2 changed files with 4 additions and 5 deletions

View File

@ -43,6 +43,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for

View File

@ -83,7 +83,6 @@ from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.custom_derivatives import custom_gradient as custom_gradient
@ -218,16 +217,15 @@ _deprecations = {
"or jax.tree_util.tree_map (any JAX version).",
_deprecated_tree_map
),
# Added Mar 18, 2024
# Finalized Nov 12 2024; remove after Feb 12 2025
"clear_backends": (
"jax.clear_backends is deprecated.",
_deprecated_clear_backends
"jax.clear_backends was removed in JAX v0.4.36",
None
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.api import clear_backends as clear_backends
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves