2024-02-12 13:07:59 -08:00
|
|
|
# Copyright 2024 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import Any, Callable, Iterable, TypeVar, overload
|
|
|
|
|
|
|
|
from jax._src import tree_util
|
|
|
|
|
|
|
|
# Generic accumulator type threaded through the `reduce` overloads and
# implementation in this module.
T = TypeVar("T")
|
|
|
|
|
|
|
|
|
2024-06-10 09:46:15 -07:00
|
|
|
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
  """Alias of :func:`jax.tree_util.tree_all`.

  Args:
    tree: the pytree to inspect.
    is_leaf: optional predicate, forwarded unchanged (as a keyword) to
      :func:`jax.tree_util.tree_all`.

  Returns:
    The ``bool`` produced by :func:`jax.tree_util.tree_all`.
  """
  # Pure delegation; ``is_leaf`` stays keyword-only on both sides.
  verdict = tree_util.tree_all(tree, is_leaf=is_leaf)
  return verdict
|
2024-02-12 13:07:59 -08:00
|
|
|
|
|
|
|
|
|
|
|
def flatten(tree: Any,
            is_leaf: Callable[[Any], bool] | None = None
            ) -> tuple[list[tree_util.Leaf], tree_util.PyTreeDef]:
  """Alias of :func:`jax.tree_util.tree_flatten`.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional predicate forwarded to
      :func:`jax.tree_util.tree_flatten`.

  Returns:
    A ``(leaves, treedef)`` pair, exactly as produced by
    :func:`jax.tree_util.tree_flatten`.
  """
  flat_leaves, treedef = tree_util.tree_flatten(tree, is_leaf=is_leaf)
  return flat_leaves, treedef
|
|
|
|
|
|
|
|
|
|
|
|
def leaves(tree: Any,
           is_leaf: Callable[[Any], bool] | None = None
           ) -> list[tree_util.Leaf]:
  """Alias of :func:`jax.tree_util.tree_leaves`.

  Args:
    tree: the pytree whose leaves are wanted.
    is_leaf: optional predicate forwarded to
      :func:`jax.tree_util.tree_leaves`.

  Returns:
    The list of leaves produced by :func:`jax.tree_util.tree_leaves`.
  """
  leaf_list = tree_util.tree_leaves(tree, is_leaf=is_leaf)
  return leaf_list
|
|
|
|
|
|
|
|
|
|
|
|
def map(f: Callable[..., Any],
        tree: Any,
        *rest: Any,
        is_leaf: Callable[[Any], bool] | None = None) -> Any:
  """Alias of :func:`jax.tree_util.tree_map`.

  Args:
    f: function applied leaf-wise; it receives one argument per tree.
    tree: the primary pytree.
    *rest: additional pytrees passed positionally after ``tree``.
    is_leaf: optional predicate forwarded to
      :func:`jax.tree_util.tree_map`.

  Returns:
    Whatever :func:`jax.tree_util.tree_map` returns for these arguments.
  """
  # Gather the primary tree and any extras into one tuple, then splat it back
  # out so tree_map sees exactly the same positional arguments.
  all_trees = (tree, *rest)
  return tree_util.tree_map(f, *all_trees, is_leaf=is_leaf)
|
|
|
|
|
|
|
|
|
|
|
|
@overload
def reduce(function: Callable[[T, Any], T],
           tree: Any,
           *,
           is_leaf: Callable[[Any], bool] | None = None) -> T:
  ...


@overload
def reduce(function: Callable[[T, Any], T],
           tree: Any,
           initializer: T,
           is_leaf: Callable[[Any], bool] | None = None) -> T:
  ...


def reduce(function: Callable[[T, Any], T],
           tree: Any,
           initializer: Any = tree_util.no_initializer,
           is_leaf: Callable[[Any], bool] | None = None) -> T:
  """Alias of :func:`jax.tree_util.tree_reduce`.

  Args:
    function: binary ``(accumulator, leaf) -> accumulator`` callable.
    tree: the pytree to fold over.
    initializer: optional starting accumulator; the default sentinel
      ``tree_util.no_initializer`` means "none supplied".
    is_leaf: optional predicate forwarded to
      :func:`jax.tree_util.tree_reduce`.

  Returns:
    The accumulator produced by :func:`jax.tree_util.tree_reduce`.
  """
  # The sentinel is forwarded untouched so tree_reduce can tell "no
  # initializer" apart from an explicit one.
  return tree_util.tree_reduce(
      function, tree, initializer=initializer, is_leaf=is_leaf)
|
|
|
|
|
|
|
|
|
|
|
|
def structure(tree: Any,
              is_leaf: Callable[[Any], bool] | None = None
              ) -> tree_util.PyTreeDef:
  """Alias of :func:`jax.tree_util.tree_structure`.

  Args:
    tree: the pytree whose structure is wanted.
    is_leaf: optional predicate forwarded to
      :func:`jax.tree_util.tree_structure`.

  Returns:
    The ``PyTreeDef`` produced by :func:`jax.tree_util.tree_structure`.
  """
  # Annotation spelled ``Callable[[Any], bool] | None`` to match every other
  # ``is_leaf`` parameter in this module.
  return tree_util.tree_structure(tree, is_leaf)
|
|
|
|
|
|
|
|
|
|
|
|
def transpose(outer_treedef: tree_util.PyTreeDef,
              inner_treedef: tree_util.PyTreeDef,
              pytree_to_transpose: Any) -> Any:
  """Alias of :func:`jax.tree_util.tree_transpose`.

  Args:
    outer_treedef: treedef of the outer structure of ``pytree_to_transpose``.
    inner_treedef: treedef of the structure nested under each outer leaf.
    pytree_to_transpose: the nested pytree to transpose.

  Returns:
    Whatever :func:`jax.tree_util.tree_transpose` returns.
  """
  transposed = tree_util.tree_transpose(
      outer_treedef, inner_treedef, pytree_to_transpose)
  return transposed
|
|
|
|
|
|
|
|
|
|
|
|
def unflatten(treedef: tree_util.PyTreeDef,
              leaves: Iterable[tree_util.Leaf]) -> Any:
  """Alias of :func:`jax.tree_util.tree_unflatten`.

  Args:
    treedef: the structure to rebuild.
    leaves: leaf values to place into ``treedef``. Within this body the
      parameter shadows the module-level ``leaves`` function.

  Returns:
    The pytree produced by :func:`jax.tree_util.tree_unflatten`.
  """
  rebuilt = tree_util.tree_unflatten(treedef, leaves)
  return rebuilt
|