mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
PiperOrigin-RevId: 705645570
This commit is contained in:
parent
ecc2673e7b
commit
26c40fadfd
@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.38
|
||||
|
||||
* Changes:
|
||||
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
|
||||
as shortcuts of the corresponding `tree_util` functions.
|
||||
|
||||
* Deprecations
|
||||
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
|
||||
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
|
||||
|
@ -13,8 +13,11 @@ List of Functions
|
||||
|
||||
all
|
||||
flatten
|
||||
flatten_with_path
|
||||
leaves
|
||||
leaves_with_path
|
||||
map
|
||||
map_with_path
|
||||
reduce
|
||||
structure
|
||||
transpose
|
||||
|
@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef,
|
||||
- :func:`jax.tree.structure`
|
||||
"""
|
||||
return tree_util.tree_unflatten(treedef, leaves)
|
||||
|
||||
|
||||
def flatten_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]:
|
||||
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten. If it contains a custom type, it is recommended
|
||||
to be registered with ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A pair which the first element is a list of key-leaf pairs, each of
|
||||
which contains a leaf and its key path. The second element is a treedef
|
||||
representing the structure of the flattened tree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
|
||||
>>> path_vals
|
||||
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||
>>> treedef
|
||||
PyTreeDef([*, {'x': *}])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
- :func:`jax.tree.map_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_flatten_with_path(tree, is_leaf)
|
||||
|
||||
|
||||
def leaves_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[tuple[tree_util.KeyPath, Any]]:
|
||||
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree. If it contains a custom type, it is recommended to be
|
||||
registered with ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.leaves_with_path([1, {'x': 3}])
|
||||
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.flatten_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_leaves_with_path(tree, is_leaf)
|
||||
|
||||
|
||||
def map_with_path(
|
||||
f: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> Any:
|
||||
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
||||
|
||||
This is a more powerful alternative of ``tree_map`` that can take the key path
|
||||
of each leaf as input argument as well.
|
||||
|
||||
Args:
|
||||
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
||||
each corresponding leaves of the pytrees.
|
||||
tree: a pytree to be mapped over, with each leaf's key path as the first
|
||||
positional argument and the leaf itself as the second argument to ``f``.
|
||||
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||
the tuple of values at corresponding nodes in ``rest``.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
|
||||
[1, 3, 5]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.map`
|
||||
- :func:`jax.tree.flatten_with_path`
|
||||
- :func:`jax.tree.leaves_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf)
|
||||
|
@ -1113,16 +1113,7 @@ def register_static(cls: type[H]) -> type[H]:
|
||||
def tree_flatten_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
|
||||
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten. If it contains a custom type, it must be
|
||||
registered with ``register_pytree_with_keys``.
|
||||
Returns:
|
||||
A pair which the first element is a list of key-leaf pairs, each of
|
||||
which contains a leaf and its key path. The second element is a treedef
|
||||
representing the structure of the flattened tree.
|
||||
"""
|
||||
"""Alias of :func:`jax.tree.flatten_with_path`."""
|
||||
return default_registry.flatten_with_path(tree, is_leaf)
|
||||
|
||||
|
||||
@ -1130,18 +1121,7 @@ def tree_flatten_with_path(
|
||||
def tree_leaves_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[tuple[KeyPath, Any]]:
|
||||
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree. If it contains a custom type, it must be registered with
|
||||
``register_pytree_with_keys``.
|
||||
Returns:
|
||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.tree_leaves`
|
||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
||||
"""
|
||||
"""Alias of :func:`jax.tree.leaves_with_path`."""
|
||||
return tree_flatten_with_path(tree, is_leaf)[0]
|
||||
|
||||
|
||||
@ -1157,31 +1137,7 @@ _generate_key_paths = generate_key_paths # alias for backward compat
|
||||
def tree_map_with_path(f: Callable[..., Any],
|
||||
tree: Any, *rest: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None) -> Any:
|
||||
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
||||
|
||||
This is a more powerful alternative of ``tree_map`` that can take the key path
|
||||
of each leaf as input argument as well.
|
||||
|
||||
Args:
|
||||
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
||||
each corresponding leaves of the pytrees.
|
||||
tree: a pytree to be mapped over, with each leaf's key path as the first
|
||||
positional argument and the leaf itself as the second argument to ``f``.
|
||||
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||
the tuple of values at corresponding nodes in ``rest``.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.tree_map`
|
||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
||||
- :func:`jax.tree_util.tree_leaves_with_path`
|
||||
"""
|
||||
|
||||
"""Alias of :func:`jax.tree.map_with_path`."""
|
||||
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
|
||||
keypath_leaves = list(zip(*keypath_leaves))
|
||||
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
|
||||
|
@ -19,8 +19,11 @@ The :mod:`jax.tree` namespace contains aliases of utilities from :mod:`jax.tree_
|
||||
|
||||
from jax._src.tree import (
|
||||
all as all,
|
||||
flatten_with_path as flatten_with_path,
|
||||
flatten as flatten,
|
||||
leaves_with_path as leaves_with_path,
|
||||
leaves as leaves,
|
||||
map_with_path as map_with_path,
|
||||
map as map,
|
||||
reduce as reduce,
|
||||
structure as structure,
|
||||
|
@ -1426,6 +1426,55 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_unflatten(treedef, leaves)
|
||||
)
|
||||
|
||||
def test_tree_flatten_with_path(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
self.assertEqual(
|
||||
jax.tree.flatten_with_path(obj),
|
||||
tree_util.tree_flatten_with_path(obj),
|
||||
)
|
||||
|
||||
def test_tree_flatten_with_path_is_leaf(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.flatten_with_path(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_leaves_with_path(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
self.assertEqual(
|
||||
jax.tree.leaves_with_path(obj),
|
||||
tree_util.tree_leaves_with_path(obj),
|
||||
)
|
||||
|
||||
def test_tree_leaves_with_path_is_leaf(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.leaves_with_path(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_map_with_path(self):
|
||||
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
|
||||
obj = [1, 2, (3, 4)]
|
||||
obj2 = [5, 6, (7, 8)]
|
||||
self.assertEqual(
|
||||
jax.tree.map_with_path(func, obj, obj2),
|
||||
tree_util.tree_map_with_path(func, obj, obj2),
|
||||
)
|
||||
|
||||
def test_tree_map_with_path_is_leaf(self):
|
||||
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
|
||||
obj = [1, 2, (3, 4)]
|
||||
obj2 = [5, 6, (7, 8)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf),
|
||||
tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
|
||||
class RegistrationTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user