Add jax.tree shortcuts for .*_with_path calls, for convenience of users.

PiperOrigin-RevId: 705645570
This commit is contained in:
Ivy Zheng 2024-12-12 15:13:01 -08:00 committed by jax authors
parent ecc2673e7b
commit 26c40fadfd
6 changed files with 156 additions and 47 deletions

View File

@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.38
* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.
* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,

View File

@ -13,8 +13,11 @@ List of Functions
all
flatten
flatten_with_path
leaves
leaves_with_path
map
map_with_path
reduce
structure
transpose

View File

@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef,
- :func:`jax.tree.structure`
"""
return tree_util.tree_unflatten(treedef, leaves)
def flatten_with_path(
tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]:
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
Args:
tree: a pytree to flatten. If it contains a custom type, it is recommended
to be registered with ``register_pytree_with_keys``.
Returns:
A pair which the first element is a list of key-leaf pairs, each of
which contains a leaf and its key path. The second element is a treedef
representing the structure of the flattened tree.
Examples:
>>> import jax
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
>>> path_vals
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
>>> treedef
PyTreeDef([*, {'x': *}])
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.map_with_path`
- :func:`jax.tree_util.register_pytree_with_keys`
"""
return tree_util.tree_flatten_with_path(tree, is_leaf)
def leaves_with_path(
tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> list[tuple[tree_util.KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
Args:
tree: a pytree. If it contains a custom type, it is recommended to be
registered with ``register_pytree_with_keys``.
Returns:
A list of key-leaf pairs, each of which contains a leaf and its key path.
Examples:
>>> import jax
>>> jax.tree.leaves_with_path([1, {'x': 3}])
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.flatten_with_path`
- :func:`jax.tree_util.register_pytree_with_keys`
"""
return tree_util.tree_leaves_with_path(tree, is_leaf)
def map_with_path(
f: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[[Any], bool] | None = None,
) -> Any:
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
This is a more powerful alternative of ``tree_map`` that can take the key path
of each leaf as input argument as well.
Args:
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
each corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf's key path as the first
positional argument and the leaf itself as the second argument to ``f``.
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
the tuple of values at corresponding nodes in ``rest``.
Examples:
>>> import jax
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
[1, 3, 5]
See Also:
- :func:`jax.tree.map`
- :func:`jax.tree.flatten_with_path`
- :func:`jax.tree.leaves_with_path`
- :func:`jax.tree_util.register_pytree_with_keys`
"""
return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf)

View File

@ -1113,16 +1113,7 @@ def register_static(cls: type[H]) -> type[H]:
def tree_flatten_with_path(
tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
Args:
tree: a pytree to flatten. If it contains a custom type, it must be
registered with ``register_pytree_with_keys``.
Returns:
A pair which the first element is a list of key-leaf pairs, each of
which contains a leaf and its key path. The second element is a treedef
representing the structure of the flattened tree.
"""
"""Alias of :func:`jax.tree.flatten_with_path`."""
return default_registry.flatten_with_path(tree, is_leaf)
@ -1130,18 +1121,7 @@ def tree_flatten_with_path(
def tree_leaves_with_path(
tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
Args:
tree: a pytree. If it contains a custom type, it must be registered with
``register_pytree_with_keys``.
Returns:
A list of key-leaf pairs, each of which contains a leaf and its key path.
See Also:
- :func:`jax.tree_util.tree_leaves`
- :func:`jax.tree_util.tree_flatten_with_path`
"""
"""Alias of :func:`jax.tree.leaves_with_path`."""
return tree_flatten_with_path(tree, is_leaf)[0]
@ -1157,31 +1137,7 @@ _generate_key_paths = generate_key_paths # alias for backward compat
def tree_map_with_path(f: Callable[..., Any],
tree: Any, *rest: Any,
is_leaf: Callable[[Any], bool] | None = None) -> Any:
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
This is a more powerful alternative of ``tree_map`` that can take the key path
of each leaf as input argument as well.
Args:
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
each corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf's key path as the first
positional argument and the leaf itself as the second argument to ``f``.
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
the tuple of values at corresponding nodes in ``rest``.
See Also:
- :func:`jax.tree_util.tree_map`
- :func:`jax.tree_util.tree_flatten_with_path`
- :func:`jax.tree_util.tree_leaves_with_path`
"""
"""Alias of :func:`jax.tree.map_with_path`."""
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
keypath_leaves = list(zip(*keypath_leaves))
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]

View File

@ -19,8 +19,11 @@ The :mod:`jax.tree` namespace contains aliases of utilities from :mod:`jax.tree_
from jax._src.tree import (
all as all,
flatten_with_path as flatten_with_path,
flatten as flatten,
leaves_with_path as leaves_with_path,
leaves as leaves,
map_with_path as map_with_path,
map as map,
reduce as reduce,
structure as structure,

View File

@ -1426,6 +1426,55 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_unflatten(treedef, leaves)
)
def test_tree_flatten_with_path(self):
obj = [1, 2, (3, 4)]
self.assertEqual(
jax.tree.flatten_with_path(obj),
tree_util.tree_flatten_with_path(obj),
)
def test_tree_flatten_with_path_is_leaf(self):
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.flatten_with_path(obj, is_leaf=is_leaf),
tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf),
)
def test_tree_leaves_with_path(self):
obj = [1, 2, (3, 4)]
self.assertEqual(
jax.tree.leaves_with_path(obj),
tree_util.tree_leaves_with_path(obj),
)
def test_tree_leaves_with_path_is_leaf(self):
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.leaves_with_path(obj, is_leaf=is_leaf),
tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf),
)
def test_tree_map_with_path(self):
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
obj = [1, 2, (3, 4)]
obj2 = [5, 6, (7, 8)]
self.assertEqual(
jax.tree.map_with_path(func, obj, obj2),
tree_util.tree_map_with_path(func, obj, obj2),
)
def test_tree_map_with_path_is_leaf(self):
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
obj = [1, 2, (3, 4)]
obj2 = [5, 6, (7, 8)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf),
tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf),
)
class RegistrationTest(jtu.JaxTestCase):