mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
tree_all: add support for is_leaf
This commit is contained in:
parent
0739d520b1
commit
814b32a44b
@ -18,6 +18,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`from jax.experimental.export import export`, and instead you should use
|
||||
`from jax.experimental import export`.
|
||||
The removed functionality has been deprecated since 0.4.24.
|
||||
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
|
||||
|
||||
* Deprecations
|
||||
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
|
||||
|
@ -20,9 +20,9 @@ from jax._src import tree_util
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def all(tree: Any) -> bool:
|
||||
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||
"""Alias of :func:`jax.tree_util.tree_all`."""
|
||||
return tree_util.tree_all(tree)
|
||||
return tree_util.tree_all(tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def flatten(tree: Any,
|
||||
|
@ -623,11 +623,15 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
|
||||
|
||||
@export
|
||||
def tree_all(tree: Any) -> bool:
|
||||
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||
"""Call all() over the leaves of a tree.
|
||||
|
||||
Args:
|
||||
tree: the pytree to evaluate
|
||||
is_leaf : an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
result: boolean True or False
|
||||
@ -643,7 +647,7 @@ def tree_all(tree: Any) -> bool:
|
||||
- :func:`jax.tree_util.tree_reduce`
|
||||
- :func:`jax.tree_util.tree_leaves`
|
||||
"""
|
||||
return all(tree_leaves(tree))
|
||||
return all(tree_leaves(tree, is_leaf=is_leaf))
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
|
@ -1153,6 +1153,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_all(obj),
|
||||
)
|
||||
|
||||
def test_tree_all_is_leaf(self):
|
||||
obj = [True, True, (True, False)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.all(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_all(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_flatten(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
self.assertEqual(
|
||||
@ -1160,6 +1168,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_flatten(obj),
|
||||
)
|
||||
|
||||
def test_tree_flatten_is_leaf(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.flatten(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_flatten(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_leaves(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
self.assertEqual(
|
||||
@ -1167,6 +1183,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_leaves(obj),
|
||||
)
|
||||
|
||||
def test_tree_leaves_is_leaf(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.leaves(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_leaves(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_map(self):
|
||||
func = lambda x: x * 2
|
||||
obj = [1, 2, (3, 4)]
|
||||
@ -1175,6 +1199,15 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_map(func, obj),
|
||||
)
|
||||
|
||||
def test_tree_map_is_leaf(self):
|
||||
func = lambda x: x * 2
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.map(func, obj, is_leaf=is_leaf),
|
||||
tree_util.tree_map(func, obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_reduce(self):
|
||||
func = lambda a, b: a + b
|
||||
obj = [1, 2, (3, 4)]
|
||||
@ -1183,6 +1216,15 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_reduce(func, obj),
|
||||
)
|
||||
|
||||
def test_tree_reduce_is_leaf(self):
|
||||
func = lambda a, b: a + b
|
||||
obj = [(1, 2), (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.reduce(func, obj, is_leaf=is_leaf),
|
||||
tree_util.tree_reduce(func, obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_structure(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
self.assertEqual(
|
||||
@ -1190,6 +1232,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
tree_util.tree_structure(obj),
|
||||
)
|
||||
|
||||
def test_tree_structure_is_leaf(self):
|
||||
obj = [1, 2, (3, 4)]
|
||||
is_leaf = lambda x: isinstance(x, tuple)
|
||||
self.assertEqual(
|
||||
jax.tree.structure(obj, is_leaf=is_leaf),
|
||||
tree_util.tree_structure(obj, is_leaf=is_leaf),
|
||||
)
|
||||
|
||||
def test_tree_transpose(self):
|
||||
obj = [(1, 2), (3, 4), (5, 6)]
|
||||
outer_treedef = tree_util.tree_structure(['*', '*', '*'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user