``jax.tree_util`` module
========================

.. currentmodule:: jax.tree_util

.. automodule:: jax.tree_util

List of Functions
-----------------

.. autosummary::
  :toctree: _autosummary

    Partial
    all_leaves
    build_tree
    register_dataclass
    register_pytree_node
    register_pytree_node_class
    register_pytree_with_keys
    register_pytree_with_keys_class
    register_static
    tree_flatten_with_path
    tree_leaves_with_path
    tree_map_with_path
    treedef_children
    treedef_is_leaf
    treedef_tuple
    KeyEntry
    KeyPath
    keystr

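
``Partial`` acts like :func:`functools.partial`, but it is also a pytree: the wrapped function is kept as static structure and the bound arguments become leaves. The minimal sketch below shows the effect by passing a partially applied function as an argument to a ``jax.jit``-compiled function. The ``scale`` and ``apply`` helpers are hypothetical and exist only for illustration.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax.tree_util import Partial

   def scale(factor, x):
       # Hypothetical helper used only for this example.
       return factor * x

   # The bound argument (factor) is a leaf; the function scale is static.
   times_three = Partial(scale, jnp.float32(3.0))

   @jax.jit
   def apply(fn, x):
       # fn arrives as a pytree argument. A plain functools.partial would
       # fail here, since it is neither a pytree node nor a JAX array type.
       return fn(x)

   out = apply(times_three, jnp.arange(3.0))
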
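
A minimal sketch of ``register_pytree_node_class``, assuming a hypothetical ``Affine`` container: the class supplies ``tree_flatten`` and ``tree_unflatten`` methods, its two arrays become leaves, and the ``name`` string is carried as static auxiliary data. As a result, tree functions such as ``jax.tree.map`` rebuild an ``Affine`` instead of returning a bare tuple, and the object can be passed directly to a jitted function.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax.tree_util import register_pytree_node_class

   @register_pytree_node_class
   class Affine:
       """Hypothetical container: two array children plus static metadata."""

       def __init__(self, w, b, name="affine"):
           self.w = w
           self.b = b
           self.name = name

       def tree_flatten(self):
           # Children are mapped over and traced; aux_data is treated as static.
           return (self.w, self.b), self.name

       @classmethod
       def tree_unflatten(cls, aux_data, children):
           w, b = children
           return cls(w, b, name=aux_data)

   layer = Affine(jnp.ones((2, 2)), jnp.zeros(2))

   # Tree functions see two leaves and rebuild an Affine on the way out.
   doubled = jax.tree.map(lambda x: 2 * x, layer)

   @jax.jit
   def apply(layer, x):
       # The registered class is flattened into leaves at the jit boundary.
       return layer.w @ x + layer.b

   y = apply(layer, jnp.ones(2))

Because ``name`` lives in the auxiliary data, it becomes part of the tree structure: instances with different names have different treedefs, so a jitted function would be traced separately for each.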
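
The key-path utilities listed above (``tree_flatten_with_path``, ``tree_map_with_path`` and ``keystr``) pair every leaf with the sequence of keys that leads to it. The sketch below flattens a small nested dict and then rescales only selected leaves. The ``params`` tree and the ``decay`` rule are hypothetical and used only for illustration.

.. code-block:: python

   import jax.numpy as jnp
   from jax.tree_util import keystr, tree_flatten_with_path, tree_map_with_path

   # Hypothetical parameter tree: a nested dict with array and scalar leaves.
   params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "scale": 2.0}

   # Each leaf is paired with its key path, a tuple of KeyEntry objects
   # (DictKey entries here, since every container is a dict).
   leaves_with_paths, treedef = tree_flatten_with_path(params)
   for path, leaf in leaves_with_paths:
       # keystr renders a key path as a readable string, e.g. "['dense']['w']".
       print(keystr(path), jnp.shape(leaf))

   def decay(path, leaf):
       # Hypothetical rule: shrink only leaves whose final key is "w".
       last = path[-1]
       return leaf * 0.9 if getattr(last, "key", None) == "w" else leaf

   # tree_map_with_path calls decay(key_path, leaf) for every leaf.
   new_params = tree_map_with_path(decay, params)

Dict keys are flattened in sorted order, so the paths come out as ``['dense']['b']``, ``['dense']['w']`` and ``['scale']``.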

Legacy APIs
-----------

These functions remain available in ``jax.tree_util``, but the preferred way to access them is now through :mod:`jax.tree`, under names without the ``tree_`` prefix (for example, ``tree_map`` is available as ``jax.tree.map``).

.. autosummary::
  :toctree: _autosummary

    tree_all
    tree_flatten
    tree_leaves
    tree_map
    tree_reduce
    tree_structure
    tree_transpose
    tree_unflatten
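
A minimal sketch comparing the legacy and ``jax.tree`` spellings, assuming a JAX release that ships :mod:`jax.tree`. The ``tree`` value is hypothetical; the checks confirm that both spellings produce the same result.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax import tree_util

   # Hypothetical pytree: a dict holding an array and a tuple of floats.
   tree = {"a": jnp.arange(3.0), "b": (1.0, 2.0)}

   # Legacy spelling and jax.tree spelling of the same operation.
   legacy = tree_util.tree_map(lambda x: x + 1, tree)
   modern = jax.tree.map(lambda x: x + 1, tree)

   # Flatten/unflatten round trip (legacy: tree_flatten / tree_unflatten).
   leaves, treedef = jax.tree.flatten(tree)
   rebuilt = jax.tree.unflatten(treedef, leaves)

   # reduce folds over the leaves in flattening order (legacy: tree_reduce).
   total = jax.tree.reduce(lambda acc, x: acc + jnp.sum(x), tree, 0.0)
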