tree_all: add support for is_leaf

This commit is contained in:
Jake VanderPlas 2024-06-10 09:46:15 -07:00
parent 0739d520b1
commit 814b32a44b
4 changed files with 59 additions and 4 deletions

View File

@ -18,6 +18,7 @@ Remember to align the itemized text with the first line of an item within a list
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use

View File

@ -20,9 +20,9 @@ from jax._src import tree_util
T = TypeVar("T")
def all(tree: Any) -> bool:
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
"""Alias of :func:`jax.tree_util.tree_all`."""
return tree_util.tree_all(tree)
return tree_util.tree_all(tree, is_leaf=is_leaf)
def flatten(tree: Any,

View File

@ -623,11 +623,15 @@ def tree_reduce(function: Callable[[T, Any], T],
@export
def tree_all(tree: Any) -> bool:
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
"""Call all() over the leaves of a tree.
Args:
tree: the pytree to evaluate
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
result: boolean True or False
@ -643,7 +647,7 @@ def tree_all(tree: Any) -> bool:
- :func:`jax.tree_util.tree_reduce`
- :func:`jax.tree_util.tree_leaves`
"""
return all(tree_leaves(tree))
return all(tree_leaves(tree, is_leaf=is_leaf))
register_pytree_node(

View File

@ -1153,6 +1153,14 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_all(obj),
)
def test_tree_all_is_leaf(self):
obj = [True, True, (True, False)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.all(obj, is_leaf=is_leaf),
tree_util.tree_all(obj, is_leaf=is_leaf),
)
def test_tree_flatten(self):
obj = [1, 2, (3, 4)]
self.assertEqual(
@ -1160,6 +1168,14 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_flatten(obj),
)
def test_tree_flatten_is_leaf(self):
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.flatten(obj, is_leaf=is_leaf),
tree_util.tree_flatten(obj, is_leaf=is_leaf),
)
def test_tree_leaves(self):
obj = [1, 2, (3, 4)]
self.assertEqual(
@ -1167,6 +1183,14 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_leaves(obj),
)
def test_tree_leaves_is_leaf(self):
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.leaves(obj, is_leaf=is_leaf),
tree_util.tree_leaves(obj, is_leaf=is_leaf),
)
def test_tree_map(self):
func = lambda x: x * 2
obj = [1, 2, (3, 4)]
@ -1175,6 +1199,15 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_map(func, obj),
)
def test_tree_map_is_leaf(self):
func = lambda x: x * 2
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.map(func, obj, is_leaf=is_leaf),
tree_util.tree_map(func, obj, is_leaf=is_leaf),
)
def test_tree_reduce(self):
func = lambda a, b: a + b
obj = [1, 2, (3, 4)]
@ -1183,6 +1216,15 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_reduce(func, obj),
)
def test_tree_reduce_is_leaf(self):
func = lambda a, b: a + b
obj = [(1, 2), (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.reduce(func, obj, is_leaf=is_leaf),
tree_util.tree_reduce(func, obj, is_leaf=is_leaf),
)
def test_tree_structure(self):
obj = [1, 2, (3, 4)]
self.assertEqual(
@ -1190,6 +1232,14 @@ class TreeAliasTest(jtu.JaxTestCase):
tree_util.tree_structure(obj),
)
def test_tree_structure_is_leaf(self):
obj = [1, 2, (3, 4)]
is_leaf = lambda x: isinstance(x, tuple)
self.assertEqual(
jax.tree.structure(obj, is_leaf=is_leaf),
tree_util.tree_structure(obj, is_leaf=is_leaf),
)
def test_tree_transpose(self):
obj = [(1, 2), (3, 4), (5, 6)]
outer_treedef = tree_util.tree_structure(['*', '*', '*'])