1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Deprecated jax.tree_util.build_tree

We have no usages of it neither in JAX nor internally, but we still have to
go through the deprecation cycle, becuase `jax.tree_util` is public API.

PiperOrigin-RevId: 739196514
This commit is contained in:
Sergei Lebedev 2025-03-21 08:53:50 -07:00 committed by jax authors
parent 27b30190be
commit 92f5d9caa3
3 changed files with 29 additions and 9 deletions

@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* Deprecations
* {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten`
instead.
## jax 0.5.3 (Mar 19, 2025)
* New Features

@ -362,6 +362,8 @@ def tree_map(f: Callable[..., Any],
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
"""Build a treedef from a nested iterable structure
DEPRECATED: Use :func:`jax.tree.unflatten` instead.
Args:
treedef: the PyTreeDef structure to build.
xs: nested iterables matching the arity as the treedef
@ -376,13 +378,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)
Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct
the tree from new values, but ``build_tree`` takes these values in terms of
a nested rather than flat structure:
>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
[(10, 11), {'a': 12, 'b': 13}]
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]
"""

@ -48,13 +48,13 @@ from jax._src.tree_util import (
PyTreeDef as PyTreeDef,
SequenceKey as SequenceKey,
all_leaves as all_leaves,
build_tree as build_tree,
build_tree as _deprecated_build_tree,
default_registry as default_registry,
keystr as keystr,
register_dataclass as register_dataclass,
register_pytree_node_class as register_pytree_node_class,
register_pytree_node as register_pytree_node,
register_pytree_with_keys_class as register_pytree_with_keys_class,
register_dataclass as register_dataclass,
register_pytree_with_keys as register_pytree_with_keys,
register_static as register_static,
tree_all as tree_all,
@ -72,3 +72,23 @@ from jax._src.tree_util import (
treedef_is_leaf as treedef_is_leaf,
treedef_tuple as treedef_tuple,
)
_deprecations = {
# Added March 21, 2025:
"build_tree": (
(
"jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten"
" instead."
),
_deprecated_build_tree,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.tree_util import build_tree as build_tree
else:
from jax._src.deprecations import deprecation_getattr
__getattr__ = deprecation_getattr(__name__, _deprecations)
del deprecation_getattr, _deprecations
del _typing