Remove deprecated function jax.tree_util.tree_multimap

This commit is contained in:
Jake VanderPlas 2022-07-06 09:06:16 -07:00
parent d840f54fe5
commit 108376d792
5 changed files with 3 additions and 10 deletions

View File

@ -16,6 +16,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
See {jax-issue}`#11557`.
* {mod}`jax.experimental.loops` has been removed. See {jax-issue}`#10278`
for an alternative API.
* {func}`jax.tree_util.tree_multimap` has been removed. It has been deprecated since
JAX release 0.3.5, and {func}`jax.tree_util.tree_map` is a direct replacement.
## jaxlib 0.3.16 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...main).

View File

@ -120,7 +120,6 @@ from jax._src.tree_util import (
_deprecated_tree_flatten as tree_flatten,
_deprecated_tree_leaves as tree_leaves,
_deprecated_tree_map as tree_map,
_deprecated_tree_multimap as tree_multimap,
_deprecated_tree_structure as tree_structure,
_deprecated_tree_transpose as tree_transpose,
_deprecated_tree_unflatten as tree_unflatten,

View File

@ -41,7 +41,7 @@ from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
tree_multimap, treedef_is_leaf, treedef_children,
treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import device_array

View File

@ -200,12 +200,6 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
def tree_multimap(*args, **kwargs):
"""Deprecated alias of :func:`jax.tree_util.tree_map`"""
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
'instead as a drop-in replacement.', FutureWarning)
return tree_map(*args, **kwargs)
def build_tree(treedef, xs):
return treedef.from_iterable_tree(xs)

View File

@ -46,8 +46,6 @@ from jax._src.tree_util import (
tree_flatten as tree_flatten,
tree_leaves as tree_leaves,
tree_map as tree_map,
# TODO(jakevdp) remove tree_multimap once deprecation is complete.
tree_multimap,
tree_reduce as tree_reduce,
tree_structure as tree_structure,
tree_transpose as tree_transpose,