mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Deprecated jax.tree_util.build_tree
We have no usages of it neither in JAX nor internally, but we still have to go through the deprecation cycle, becuase `jax.tree_util` is public API. PiperOrigin-RevId: 739196514
This commit is contained in:
parent
27b30190be
commit
92f5d9caa3
@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## Unreleased
|
||||
|
||||
* Deprecations
|
||||
|
||||
* {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten`
|
||||
instead.
|
||||
|
||||
## jax 0.5.3 (Mar 19, 2025)
|
||||
|
||||
* New Features
|
||||
|
@ -362,6 +362,8 @@ def tree_map(f: Callable[..., Any],
|
||||
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
|
||||
"""Build a treedef from a nested iterable structure
|
||||
|
||||
DEPRECATED: Use :func:`jax.tree.unflatten` instead.
|
||||
|
||||
Args:
|
||||
treedef: the PyTreeDef structure to build.
|
||||
xs: nested iterables matching the arity as the treedef
|
||||
@ -376,13 +378,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
|
||||
>>> import jax
|
||||
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
|
||||
>>> treedef = jax.tree.structure(tree)
|
||||
|
||||
Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct
|
||||
the tree from new values, but ``build_tree`` takes these values in terms of
|
||||
a nested rather than flat structure:
|
||||
|
||||
>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
|
||||
[(10, 11), {'a': 12, 'b': 13}]
|
||||
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
|
||||
[(10, 11), {'a': 12, 'b': 13}]
|
||||
"""
|
||||
|
@ -48,13 +48,13 @@ from jax._src.tree_util import (
|
||||
PyTreeDef as PyTreeDef,
|
||||
SequenceKey as SequenceKey,
|
||||
all_leaves as all_leaves,
|
||||
build_tree as build_tree,
|
||||
build_tree as _deprecated_build_tree,
|
||||
default_registry as default_registry,
|
||||
keystr as keystr,
|
||||
register_dataclass as register_dataclass,
|
||||
register_pytree_node_class as register_pytree_node_class,
|
||||
register_pytree_node as register_pytree_node,
|
||||
register_pytree_with_keys_class as register_pytree_with_keys_class,
|
||||
register_dataclass as register_dataclass,
|
||||
register_pytree_with_keys as register_pytree_with_keys,
|
||||
register_static as register_static,
|
||||
tree_all as tree_all,
|
||||
@ -72,3 +72,23 @@ from jax._src.tree_util import (
|
||||
treedef_is_leaf as treedef_is_leaf,
|
||||
treedef_tuple as treedef_tuple,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added March 21, 2025:
|
||||
"build_tree": (
|
||||
(
|
||||
"jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten"
|
||||
" instead."
|
||||
),
|
||||
_deprecated_build_tree,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from jax._src.tree_util import build_tree as build_tree
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr
|
||||
__getattr__ = deprecation_getattr(__name__, _deprecations)
|
||||
del deprecation_getattr, _deprecations
|
||||
del _typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user