diff --git a/CHANGELOG.md b/CHANGELOG.md index e7823a3ee..c3b306b4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Remember to align the itemized text with the first line of an item within a list `from jax.experimental.export import export`, and instead you should use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24. + * Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`. * Deprecations * `jax.sharding.XLACompatibleSharding` is deprecated. Please use diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 9fdd8ecb8..be8c70732 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -20,9 +20,9 @@ from jax._src import tree_util T = TypeVar("T") -def all(tree: Any) -> bool: +def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: """Alias of :func:`jax.tree_util.tree_all`.""" - return tree_util.tree_all(tree) + return tree_util.tree_all(tree, is_leaf=is_leaf) def flatten(tree: Any, diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 857ad46ff..b502c3b0b 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -623,11 +623,15 @@ def tree_reduce(function: Callable[[T, Any], T], @export -def tree_all(tree: Any) -> bool: +def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: """Call all() over the leaves of a tree. Args: tree: the pytree to evaluate + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. Returns: result: boolean True or False @@ -643,7 +647,7 @@ def tree_all(tree: Any) -> bool: - :func:`jax.tree_util.tree_reduce` - :func:`jax.tree_util.tree_leaves` """ - return all(tree_leaves(tree)) + return all(tree_leaves(tree, is_leaf=is_leaf)) register_pytree_node( diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 74e21eecd..23ddf7390 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -1153,6 +1153,14 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_all(obj), ) + def test_tree_all_is_leaf(self): + obj = [True, True, (True, False)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.all(obj, is_leaf=is_leaf), + tree_util.tree_all(obj, is_leaf=is_leaf), + ) + def test_tree_flatten(self): obj = [1, 2, (3, 4)] self.assertEqual( @@ -1160,6 +1168,14 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_flatten(obj), ) + def test_tree_flatten_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.flatten(obj, is_leaf=is_leaf), + tree_util.tree_flatten(obj, is_leaf=is_leaf), + ) + def test_tree_leaves(self): obj = [1, 2, (3, 4)] self.assertEqual( @@ -1167,6 +1183,14 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_leaves(obj), ) + def test_tree_leaves_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.leaves(obj, is_leaf=is_leaf), + tree_util.tree_leaves(obj, is_leaf=is_leaf), + ) + def test_tree_map(self): func = lambda x: x * 2 obj = [1, 2, (3, 4)] @@ -1175,6 +1199,15 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_map(func, obj), ) + def test_tree_map_is_leaf(self): + func = lambda x: x * 2 + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.map(func, obj, is_leaf=is_leaf), + tree_util.tree_map(func, obj, is_leaf=is_leaf), + ) + def test_tree_reduce(self): func = lambda a, b: a + b obj = [1, 2, (3, 4)] @@ -1183,6 +1216,15 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_reduce(func, obj), ) + def test_tree_reduce_is_leaf(self): + func = lambda a, b: a + b + obj = [(1, 2), (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.reduce(func, obj, is_leaf=is_leaf), + tree_util.tree_reduce(func, obj, is_leaf=is_leaf), + ) + def test_tree_structure(self): obj = [1, 2, (3, 4)] self.assertEqual( @@ -1190,6 +1232,14 @@ class TreeAliasTest(jtu.JaxTestCase): tree_util.tree_structure(obj), ) + def test_tree_structure_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.structure(obj, is_leaf=is_leaf), + tree_util.tree_structure(obj, is_leaf=is_leaf), + ) + def test_tree_transpose(self): obj = [(1, 2), (3, 4), (5, 6)] outer_treedef = tree_util.tree_structure(['*', '*', '*'])