mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
PiperOrigin-RevId: 705645570
This commit is contained in:
parent
ecc2673e7b
commit
26c40fadfd
@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
|
|
||||||
## jax 0.4.38
|
## jax 0.4.38
|
||||||
|
|
||||||
|
* Changes:
|
||||||
|
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
|
||||||
|
as shortcuts of the corresponding `tree_util` functions.
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
|
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
|
||||||
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
|
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
|
||||||
|
@ -13,8 +13,11 @@ List of Functions
|
|||||||
|
|
||||||
all
|
all
|
||||||
flatten
|
flatten
|
||||||
|
flatten_with_path
|
||||||
leaves
|
leaves
|
||||||
|
leaves_with_path
|
||||||
map
|
map
|
||||||
|
map_with_path
|
||||||
reduce
|
reduce
|
||||||
structure
|
structure
|
||||||
transpose
|
transpose
|
||||||
|
@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef,
|
|||||||
- :func:`jax.tree.structure`
|
- :func:`jax.tree.structure`
|
||||||
"""
|
"""
|
||||||
return tree_util.tree_unflatten(treedef, leaves)
|
return tree_util.tree_unflatten(treedef, leaves)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_with_path(
|
||||||
|
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||||
|
) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]:
|
||||||
|
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tree: a pytree to flatten. If it contains a custom type, it is recommended
|
||||||
|
to be registered with ``register_pytree_with_keys``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pair which the first element is a list of key-leaf pairs, each of
|
||||||
|
which contains a leaf and its key path. The second element is a treedef
|
||||||
|
representing the structure of the flattened tree.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import jax
|
||||||
|
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
|
||||||
|
>>> path_vals
|
||||||
|
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||||
|
>>> treedef
|
||||||
|
PyTreeDef([*, {'x': *}])
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- :func:`jax.tree.flatten`
|
||||||
|
- :func:`jax.tree.map_with_path`
|
||||||
|
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||||
|
"""
|
||||||
|
return tree_util.tree_flatten_with_path(tree, is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
def leaves_with_path(
|
||||||
|
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||||
|
) -> list[tuple[tree_util.KeyPath, Any]]:
|
||||||
|
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tree: a pytree. If it contains a custom type, it is recommended to be
|
||||||
|
registered with ``register_pytree_with_keys``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import jax
|
||||||
|
>>> jax.tree.leaves_with_path([1, {'x': 3}])
|
||||||
|
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- :func:`jax.tree.leaves`
|
||||||
|
- :func:`jax.tree.flatten_with_path`
|
||||||
|
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||||
|
"""
|
||||||
|
return tree_util.tree_leaves_with_path(tree, is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
def map_with_path(
|
||||||
|
f: Callable[..., Any],
|
||||||
|
tree: Any,
|
||||||
|
*rest: Any,
|
||||||
|
is_leaf: Callable[[Any], bool] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
||||||
|
|
||||||
|
This is a more powerful alternative of ``tree_map`` that can take the key path
|
||||||
|
of each leaf as input argument as well.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
||||||
|
each corresponding leaves of the pytrees.
|
||||||
|
tree: a pytree to be mapped over, with each leaf's key path as the first
|
||||||
|
positional argument and the leaf itself as the second argument to ``f``.
|
||||||
|
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||||
|
or has ``tree`` as a prefix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new pytree with the same structure as ``tree`` but with the value at each
|
||||||
|
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||||
|
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||||
|
the tuple of values at corresponding nodes in ``rest``.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import jax
|
||||||
|
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
|
||||||
|
[1, 3, 5]
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- :func:`jax.tree.map`
|
||||||
|
- :func:`jax.tree.flatten_with_path`
|
||||||
|
- :func:`jax.tree.leaves_with_path`
|
||||||
|
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||||
|
"""
|
||||||
|
return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf)
|
||||||
|
@ -1113,16 +1113,7 @@ def register_static(cls: type[H]) -> type[H]:
|
|||||||
def tree_flatten_with_path(
|
def tree_flatten_with_path(
|
||||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||||
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
|
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
|
||||||
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
"""Alias of :func:`jax.tree.flatten_with_path`."""
|
||||||
|
|
||||||
Args:
|
|
||||||
tree: a pytree to flatten. If it contains a custom type, it must be
|
|
||||||
registered with ``register_pytree_with_keys``.
|
|
||||||
Returns:
|
|
||||||
A pair which the first element is a list of key-leaf pairs, each of
|
|
||||||
which contains a leaf and its key path. The second element is a treedef
|
|
||||||
representing the structure of the flattened tree.
|
|
||||||
"""
|
|
||||||
return default_registry.flatten_with_path(tree, is_leaf)
|
return default_registry.flatten_with_path(tree, is_leaf)
|
||||||
|
|
||||||
|
|
||||||
@ -1130,18 +1121,7 @@ def tree_flatten_with_path(
|
|||||||
def tree_leaves_with_path(
|
def tree_leaves_with_path(
|
||||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||||
) -> list[tuple[KeyPath, Any]]:
|
) -> list[tuple[KeyPath, Any]]:
|
||||||
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
"""Alias of :func:`jax.tree.leaves_with_path`."""
|
||||||
|
|
||||||
Args:
|
|
||||||
tree: a pytree. If it contains a custom type, it must be registered with
|
|
||||||
``register_pytree_with_keys``.
|
|
||||||
Returns:
|
|
||||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
|
||||||
|
|
||||||
See Also:
|
|
||||||
- :func:`jax.tree_util.tree_leaves`
|
|
||||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
|
||||||
"""
|
|
||||||
return tree_flatten_with_path(tree, is_leaf)[0]
|
return tree_flatten_with_path(tree, is_leaf)[0]
|
||||||
|
|
||||||
|
|
||||||
@ -1157,31 +1137,7 @@ _generate_key_paths = generate_key_paths # alias for backward compat
|
|||||||
def tree_map_with_path(f: Callable[..., Any],
|
def tree_map_with_path(f: Callable[..., Any],
|
||||||
tree: Any, *rest: Any,
|
tree: Any, *rest: Any,
|
||||||
is_leaf: Callable[[Any], bool] | None = None) -> Any:
|
is_leaf: Callable[[Any], bool] | None = None) -> Any:
|
||||||
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
"""Alias of :func:`jax.tree.map_with_path`."""
|
||||||
|
|
||||||
This is a more powerful alternative of ``tree_map`` that can take the key path
|
|
||||||
of each leaf as input argument as well.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
|
||||||
each corresponding leaves of the pytrees.
|
|
||||||
tree: a pytree to be mapped over, with each leaf's key path as the first
|
|
||||||
positional argument and the leaf itself as the second argument to ``f``.
|
|
||||||
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
|
||||||
or has ``tree`` as a prefix.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new pytree with the same structure as ``tree`` but with the value at each
|
|
||||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
|
||||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
|
||||||
the tuple of values at corresponding nodes in ``rest``.
|
|
||||||
|
|
||||||
See Also:
|
|
||||||
- :func:`jax.tree_util.tree_map`
|
|
||||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
|
||||||
- :func:`jax.tree_util.tree_leaves_with_path`
|
|
||||||
"""
|
|
||||||
|
|
||||||
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
|
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
|
||||||
keypath_leaves = list(zip(*keypath_leaves))
|
keypath_leaves = list(zip(*keypath_leaves))
|
||||||
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
|
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
|
||||||
|
@ -19,8 +19,11 @@ The :mod:`jax.tree` namespace contains aliases of utilities from :mod:`jax.tree_
|
|||||||
|
|
||||||
from jax._src.tree import (
|
from jax._src.tree import (
|
||||||
all as all,
|
all as all,
|
||||||
|
flatten_with_path as flatten_with_path,
|
||||||
flatten as flatten,
|
flatten as flatten,
|
||||||
|
leaves_with_path as leaves_with_path,
|
||||||
leaves as leaves,
|
leaves as leaves,
|
||||||
|
map_with_path as map_with_path,
|
||||||
map as map,
|
map as map,
|
||||||
reduce as reduce,
|
reduce as reduce,
|
||||||
structure as structure,
|
structure as structure,
|
||||||
|
@ -1426,6 +1426,55 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_unflatten(treedef, leaves)
|
tree_util.tree_unflatten(treedef, leaves)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_flatten_with_path(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.flatten_with_path(obj),
|
||||||
|
tree_util.tree_flatten_with_path(obj),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tree_flatten_with_path_is_leaf(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.flatten_with_path(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tree_leaves_with_path(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.leaves_with_path(obj),
|
||||||
|
tree_util.tree_leaves_with_path(obj),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tree_leaves_with_path_is_leaf(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.leaves_with_path(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tree_map_with_path(self):
|
||||||
|
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
obj2 = [5, 6, (7, 8)]
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.map_with_path(func, obj, obj2),
|
||||||
|
tree_util.tree_map_with_path(func, obj, obj2),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tree_map_with_path_is_leaf(self):
|
||||||
|
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
obj2 = [5, 6, (7, 8)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationTest(jtu.JaxTestCase):
|
class RegistrationTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user