mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove deprecated function jax.tree_util.tree_multimap
This commit is contained in:
parent
d840f54fe5
commit
108376d792
@ -16,6 +16,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
See {jax-issue}`#11557`.
|
||||
* {mod}`jax.experimental.loops` has been removed. See {jax-issue}`#10278`
|
||||
for an alternative API.
|
||||
* {func}`jax.tree_util.tree_multimap` has been removed. It has been deprecated since
|
||||
JAX release 0.3.5, and {func}`jax.tree_util.tree_map` is a direct replacement.
|
||||
|
||||
## jaxlib 0.3.16 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...main).
|
||||
|
@ -120,7 +120,6 @@ from jax._src.tree_util import (
|
||||
_deprecated_tree_flatten as tree_flatten,
|
||||
_deprecated_tree_leaves as tree_leaves,
|
||||
_deprecated_tree_map as tree_map,
|
||||
_deprecated_tree_multimap as tree_multimap,
|
||||
_deprecated_tree_structure as tree_structure,
|
||||
_deprecated_tree_transpose as tree_transpose,
|
||||
_deprecated_tree_unflatten as tree_unflatten,
|
||||
|
@ -41,7 +41,7 @@ from jax import stages
|
||||
from jax.core import eval_jaxpr
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
tree_multimap, treedef_is_leaf, treedef_children,
|
||||
treedef_is_leaf, treedef_children,
|
||||
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
||||
|
||||
from jax._src import device_array
|
||||
|
@ -200,12 +200,6 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||
|
||||
def tree_multimap(*args, **kwargs):
|
||||
"""Deprecated alias of :func:`jax.tree_util.tree_map`"""
|
||||
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
|
||||
'instead as a drop-in replacement.', FutureWarning)
|
||||
return tree_map(*args, **kwargs)
|
||||
|
||||
def build_tree(treedef, xs):
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
|
@ -46,8 +46,6 @@ from jax._src.tree_util import (
|
||||
tree_flatten as tree_flatten,
|
||||
tree_leaves as tree_leaves,
|
||||
tree_map as tree_map,
|
||||
# TODO(jakevdp) remove tree_multimap once deprecation is complete.
|
||||
tree_multimap,
|
||||
tree_reduce as tree_reduce,
|
||||
tree_structure as tree_structure,
|
||||
tree_transpose as tree_transpose,
|
||||
|
Loading…
x
Reference in New Issue
Block a user