Deprecate jax.clear_backends.

`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
This commit is contained in:
Yue Sheng 2024-03-18 14:22:35 -07:00 committed by jax authors
parent 154403c03d
commit 147c363ea6
2 changed files with 13 additions and 1 deletions

View File

@ -11,6 +11,12 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* {func}`jax.clear_backends` is deprecated as it does not necessarily do what
its name suggests and can lead to unexpected consequences, e.g., it will not
destroy existing backends and release corresponding owned resources. Use
{func}`jax.clear_caches` if you only want to clean up compilation caches.
For backward compatibility or you really need to switch/reinitialize the
default backend, use {func}`jax.extend.backend.clear_backends`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
`spmd_axis_name` argument for expressing SPMD device-parallel computations.

View File

@ -81,7 +81,7 @@ from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.custom_derivatives import custom_gradient as custom_gradient
@ -218,10 +218,16 @@ _deprecations = {
"or jax.tree_util.tree_map (any JAX version).",
_deprecated_tree_map
),
# Added Mar 18, 2024
"clear_backends": (
"jax.clear_backends is deprecated.",
_deprecated_clear_backends
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.api import clear_backends as clear_backends
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves