mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3 PiperOrigin-RevId: 611230367
This commit is contained in:
parent
e079bb9938
commit
236275ebe1
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.26
|
||||
|
||||
* Deprecations
|
||||
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
||||
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
||||
|
||||
## jaxlib 0.4.26
|
||||
|
||||
## jax 0.4.25 (Feb 26, 2024)
|
||||
|
@ -136,7 +136,7 @@ from jax._src.array import (
|
||||
)
|
||||
|
||||
from jax._src.tree_util import (
|
||||
tree_map as tree_map,
|
||||
tree_map as _deprecated_tree_map,
|
||||
treedef_is_leaf as _deprecated_treedef_is_leaf,
|
||||
tree_flatten as _deprecated_tree_flatten,
|
||||
tree_leaves as _deprecated_tree_leaves,
|
||||
@ -212,6 +212,12 @@ _deprecations = {
|
||||
"or jax.tree_util.tree_unflatten (any JAX version).",
|
||||
_deprecated_tree_unflatten
|
||||
),
|
||||
# Added Feb 28, 2024
|
||||
"tree_map": (
|
||||
"jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) "
|
||||
"or jax.tree_util.tree_map (any JAX version).",
|
||||
_deprecated_tree_map
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
@ -219,6 +225,7 @@ if _typing.TYPE_CHECKING:
|
||||
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
|
||||
from jax._src.tree_util import tree_flatten as tree_flatten
|
||||
from jax._src.tree_util import tree_leaves as tree_leaves
|
||||
from jax._src.tree_util import tree_map as tree_map
|
||||
from jax._src.tree_util import tree_structure as tree_structure
|
||||
from jax._src.tree_util import tree_transpose as tree_transpose
|
||||
from jax._src.tree_util import tree_unflatten as tree_unflatten
|
||||
|
Loading…
x
Reference in New Issue
Block a user