add a tree_util.py module-level docstring

This commit is contained in:
Matthew Johnson 2019-05-02 08:02:01 -07:00
parent 21118d0dff
commit 87a150e567
2 changed files with 25 additions and 4 deletions

View File

@ -13,11 +13,13 @@
# limitations under the License.
"""
User-facing transformations.
JAX user-facing transformations and utilities.
These mostly wrap internal transformations, providing convenience flags to
control behavior and handling Python containers (tuples/lists/dicts) of
arguments and outputs.
The transformations here mostly wrap internal transformations, providing
convenience flags to control behavior and handling Python containers of
arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays or JaxTuples.
"""
from __future__ import absolute_import

View File

@ -12,6 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with tree-like container data structures.
The code here is independent of JAX. The only dependence is on jax.util, which
itself has no JAX-specific code.
This module provides a small set of utility functions for working with tree-like
data structures, such as nested tuples, lists, and dicts. We call these
structures pytrees. They are trees in that they are defined recursively (any
non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
can be operated on recursively (object identity equivalence is not preserved by
mapping operations, and the structures cannot contain reference cycles).
The set of Python types that are considered pytree nodes (e.g. that can be
mapped over, rather than treated as leaves) is extensible. There is a single
module-level registry of types, and class hierarchy is ignored. By registering a
new pytree node type, that type in effect becomes transparent to the utility
functions in this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function