2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-05-02 08:02:01 -07:00
|
|
|
"""Utilities for working with tree-like container data structures.
|
|
|
|
|
|
|
|
This module provides a small set of utility functions for working with tree-like
|
|
|
|
data structures, such as nested tuples, lists, and dicts. We call these
|
|
|
|
structures pytrees. They are trees in that they are defined recursively (any
|
|
|
|
non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
|
|
|
|
can be operated on recursively (object identity equivalence is not preserved by
|
|
|
|
mapping operations, and the structures cannot contain reference cycles).
|
|
|
|
|
|
|
|
The set of Python types that are considered pytree nodes (i.e. that can be
|
|
|
|
mapped over, rather than treated as leaves) is extensible. There is a single
|
|
|
|
module-level registry of types, and class hierarchy is ignored. By registering a
|
|
|
|
new pytree node type, that type in effect becomes transparent to the utility
|
|
|
|
functions in this file.
|
2019-08-23 16:54:59 -07:00
|
|
|
|
|
|
|
The primary purpose of this module is to enable interoperability between
|
2019-08-23 19:32:45 -07:00
|
|
|
user-defined data structures and JAX transformations (e.g. `jit`). This is not
|
|
|
|
meant to be a general-purpose tree-like data structure handling library.
|
2019-10-27 10:29:33 +01:00
|
|
|
|
2020-07-06 09:04:02 +03:00
|
|
|
See the `JAX pytrees note <pytrees.html>`_
|
2019-10-27 10:29:33 +01:00
|
|
|
for examples.
|
2019-05-02 08:02:01 -07:00
|
|
|
"""
|
|
|
|
|
2021-03-24 12:00:12 -07:00
|
|
|
from jax._src.tree_util import (
|
2021-08-30 14:35:22 -07:00
|
|
|
Partial as Partial,
|
|
|
|
PyTreeDef as PyTreeDef,
|
|
|
|
all_leaves as all_leaves,
|
|
|
|
build_tree as build_tree,
|
|
|
|
register_pytree_node as register_pytree_node,
|
|
|
|
register_pytree_node_class as register_pytree_node_class,
|
|
|
|
tree_all as tree_all,
|
|
|
|
tree_flatten as tree_flatten,
|
|
|
|
tree_leaves as tree_leaves,
|
|
|
|
tree_map as tree_map,
|
|
|
|
tree_reduce as tree_reduce,
|
|
|
|
tree_structure as tree_structure,
|
|
|
|
tree_transpose as tree_transpose,
|
|
|
|
tree_unflatten as tree_unflatten,
|
|
|
|
treedef_children as treedef_children,
|
|
|
|
treedef_is_leaf as treedef_is_leaf,
|
|
|
|
treedef_tuple as treedef_tuple,
|
2022-02-08 12:45:38 -08:00
|
|
|
register_keypaths as register_keypaths,
|
|
|
|
AttributeKeyPathEntry as AttributeKeyPathEntry,
|
|
|
|
GetitemKeyPathEntry as GetitemKeyPathEntry,
|
2019-07-09 12:05:59 -07:00
|
|
|
)
|