From 26c40fadfd55f548e0d055379e6e144e2a5b17a0 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 12 Dec 2024 15:13:01 -0800 Subject: [PATCH] Add jax.tree shortcuts for .*_with_path calls, for convenience of users. PiperOrigin-RevId: 705645570 --- CHANGELOG.md | 4 ++ docs/jax.tree.rst | 3 ++ jax/_src/tree.py | 94 +++++++++++++++++++++++++++++++++++++++++ jax/_src/tree_util.py | 50 ++-------------------- jax/tree.py | 3 ++ tests/tree_util_test.py | 49 +++++++++++++++++++++ 6 files changed, 156 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 395b3a784..1c8ede6ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.38 +* Changes: + * `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added + as shortcuts of the corresponding `tree_util` functions. + * Deprecations * a number of APIs in the internal `jax.core` namespace have been deprecated, including `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`, diff --git a/docs/jax.tree.rst b/docs/jax.tree.rst index a9d199340..e65c77c75 100644 --- a/docs/jax.tree.rst +++ b/docs/jax.tree.rst @@ -13,8 +13,11 @@ List of Functions all flatten + flatten_with_path leaves + leaves_with_path map + map_with_path reduce structure transpose diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 9719308ca..70d75a126 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef, - :func:`jax.tree.structure` """ return tree_util.tree_unflatten(treedef, leaves) + + +def flatten_with_path( + tree: Any, is_leaf: Callable[[Any], bool] | None = None +) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]: + """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, it is recommended + to be registered with ``register_pytree_with_keys``. + + Returns: + A pair which the first element is a list of key-leaf pairs, each of + which contains a leaf and its key path. The second element is a treedef + representing the structure of the flattened tree. + + Examples: + >>> import jax + >>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}]) + >>> path_vals + [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)] + >>> treedef + PyTreeDef([*, {'x': *}]) + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.map_with_path` + - :func:`jax.tree_util.register_pytree_with_keys` + """ + return tree_util.tree_flatten_with_path(tree, is_leaf) + + +def leaves_with_path( + tree: Any, is_leaf: Callable[[Any], bool] | None = None +) -> list[tuple[tree_util.KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, it is recommended to be + registered with ``register_pytree_with_keys``. + + Returns: + A list of key-leaf pairs, each of which contains a leaf and its key path. + + Examples: + >>> import jax + >>> jax.tree.leaves_with_path([1, {'x': 3}]) + [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)] + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.flatten_with_path` + - :func:`jax.tree_util.register_pytree_with_keys` + """ + return tree_util.tree_leaves_with_path(tree, is_leaf) + + +def map_with_path( + f: Callable[..., Any], + tree: Any, + *rest: Any, + is_leaf: Callable[[Any], bool] | None = None, +) -> Any: + """Maps a multi-input function over pytree key path and args to produce a new pytree. + + This is a more powerful alternative of ``tree_map`` that can take the key path + of each leaf as input argument as well. + + Args: + f: function that takes ``2 + len(rest)`` arguments, aka. the key path and + each corresponding leaves of the pytrees. + tree: a pytree to be mapped over, with each leaf's key path as the first + positional argument and the leaf itself as the second argument to ``f``. + *rest: a tuple of pytrees, each of which has the same structure as ``tree`` + or has ``tree`` as a prefix. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at + the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is + the tuple of values at corresponding nodes in ``rest``. + + Examples: + >>> import jax + >>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3]) + [1, 3, 5] + + See Also: + - :func:`jax.tree.map` + - :func:`jax.tree.flatten_with_path` + - :func:`jax.tree.leaves_with_path` + - :func:`jax.tree_util.register_pytree_with_keys` + """ + return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 77871f3a9..8e894daf7 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -1113,16 +1113,7 @@ def register_static(cls: type[H]) -> type[H]: def tree_flatten_with_path( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: - """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. - - Args: - tree: a pytree to flatten. If it contains a custom type, it must be - registered with ``register_pytree_with_keys``. - Returns: - A pair which the first element is a list of key-leaf pairs, each of - which contains a leaf and its key path. The second element is a treedef - representing the structure of the flattened tree. - """ + """Alias of :func:`jax.tree.flatten_with_path`.""" return default_registry.flatten_with_path(tree, is_leaf) @@ -1130,18 +1121,7 @@ def tree_flatten_with_path( def tree_leaves_with_path( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: - """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. - - Args: - tree: a pytree. If it contains a custom type, it must be registered with - ``register_pytree_with_keys``. - Returns: - A list of key-leaf pairs, each of which contains a leaf and its key path. - - See Also: - - :func:`jax.tree_util.tree_leaves` - - :func:`jax.tree_util.tree_flatten_with_path` - """ + """Alias of :func:`jax.tree.leaves_with_path`.""" return tree_flatten_with_path(tree, is_leaf)[0] @@ -1157,31 +1137,7 @@ _generate_key_paths = generate_key_paths # alias for backward compat def tree_map_with_path(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any: - """Maps a multi-input function over pytree key path and args to produce a new pytree. - - This is a more powerful alternative of ``tree_map`` that can take the key path - of each leaf as input argument as well. - - Args: - f: function that takes ``2 + len(rest)`` arguments, aka. the key path and - each corresponding leaves of the pytrees. - tree: a pytree to be mapped over, with each leaf's key path as the first - positional argument and the leaf itself as the second argument to ``f``. - *rest: a tuple of pytrees, each of which has the same structure as ``tree`` - or has ``tree`` as a prefix. - - Returns: - A new pytree with the same structure as ``tree`` but with the value at each - leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at - the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is - the tuple of values at corresponding nodes in ``rest``. - - See Also: - - :func:`jax.tree_util.tree_map` - - :func:`jax.tree_util.tree_flatten_with_path` - - :func:`jax.tree_util.tree_leaves_with_path` - """ - + """Alias of :func:`jax.tree.map_with_path`.""" keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) keypath_leaves = list(zip(*keypath_leaves)) all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest] diff --git a/jax/tree.py b/jax/tree.py index 9b01764d6..270c34fe9 100644 --- a/jax/tree.py +++ b/jax/tree.py @@ -19,8 +19,11 @@ The :mod:`jax.tree` namespace contains aliases of utilities from :mod:`jax.tree_ from jax._src.tree import ( all as all, + flatten_with_path as flatten_with_path, flatten as flatten, + leaves_with_path as leaves_with_path, leaves as leaves, + map_with_path as map_with_path, map as map, reduce as reduce, structure as structure, diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 1b921121e..834af9c5f 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -1426,6 +1426,55 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_unflatten(treedef, leaves) ) + def test_tree_flatten_with_path(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.flatten_with_path(obj), + tree_util.tree_flatten_with_path(obj), + ) + + def test_tree_flatten_with_path_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.flatten_with_path(obj, is_leaf=is_leaf), + tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), + ) + + def test_tree_leaves_with_path(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.leaves_with_path(obj), + tree_util.tree_leaves_with_path(obj), + ) + + def test_tree_leaves_with_path_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.leaves_with_path(obj, is_leaf=is_leaf), + tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), + ) + + def test_tree_map_with_path(self): + func = lambda kp, x, y: (sum(k.idx for k in kp), x + y) + obj = [1, 2, (3, 4)] + obj2 = [5, 6, (7, 8)] + self.assertEqual( + jax.tree.map_with_path(func, obj, obj2), + tree_util.tree_map_with_path(func, obj, obj2), + ) + + def test_tree_map_with_path_is_leaf(self): + func = lambda kp, x, y: (sum(k.idx for k in kp), x + y) + obj = [1, 2, (3, 4)] + obj2 = [5, 6, (7, 8)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), + tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), + ) + class RegistrationTest(jtu.JaxTestCase):