mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
tree_all: add support for is_leaf
This commit is contained in:
parent
0739d520b1
commit
814b32a44b
@ -18,6 +18,7 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
`from jax.experimental.export import export`, and instead you should use
|
`from jax.experimental.export import export`, and instead you should use
|
||||||
`from jax.experimental import export`.
|
`from jax.experimental import export`.
|
||||||
The removed functionality has been deprecated since 0.4.24.
|
The removed functionality has been deprecated since 0.4.24.
|
||||||
|
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
|
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
|
||||||
|
@ -20,9 +20,9 @@ from jax._src import tree_util
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def all(tree: Any) -> bool:
|
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||||
"""Alias of :func:`jax.tree_util.tree_all`."""
|
"""Alias of :func:`jax.tree_util.tree_all`."""
|
||||||
return tree_util.tree_all(tree)
|
return tree_util.tree_all(tree, is_leaf=is_leaf)
|
||||||
|
|
||||||
|
|
||||||
def flatten(tree: Any,
|
def flatten(tree: Any,
|
||||||
|
@ -623,11 +623,15 @@ def tree_reduce(function: Callable[[T, Any], T],
|
|||||||
|
|
||||||
|
|
||||||
@export
|
@export
|
||||||
def tree_all(tree: Any) -> bool:
|
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||||
"""Call all() over the leaves of a tree.
|
"""Call all() over the leaves of a tree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree: the pytree to evaluate
|
tree: the pytree to evaluate
|
||||||
|
is_leaf : an optionally specified function that will be called at each
|
||||||
|
flattening step. It should return a boolean, which indicates whether the
|
||||||
|
flattening should traverse the current object, or if it should be stopped
|
||||||
|
immediately, with the whole subtree being treated as a leaf.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result: boolean True or False
|
result: boolean True or False
|
||||||
@ -643,7 +647,7 @@ def tree_all(tree: Any) -> bool:
|
|||||||
- :func:`jax.tree_util.tree_reduce`
|
- :func:`jax.tree_util.tree_reduce`
|
||||||
- :func:`jax.tree_util.tree_leaves`
|
- :func:`jax.tree_util.tree_leaves`
|
||||||
"""
|
"""
|
||||||
return all(tree_leaves(tree))
|
return all(tree_leaves(tree, is_leaf=is_leaf))
|
||||||
|
|
||||||
|
|
||||||
register_pytree_node(
|
register_pytree_node(
|
||||||
|
@ -1153,6 +1153,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_all(obj),
|
tree_util.tree_all(obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_all_is_leaf(self):
|
||||||
|
obj = [True, True, (True, False)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.all(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_all(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_flatten(self):
|
def test_tree_flatten(self):
|
||||||
obj = [1, 2, (3, 4)]
|
obj = [1, 2, (3, 4)]
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -1160,6 +1168,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_flatten(obj),
|
tree_util.tree_flatten(obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_flatten_is_leaf(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.flatten(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_flatten(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_leaves(self):
|
def test_tree_leaves(self):
|
||||||
obj = [1, 2, (3, 4)]
|
obj = [1, 2, (3, 4)]
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -1167,6 +1183,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_leaves(obj),
|
tree_util.tree_leaves(obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_leaves_is_leaf(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.leaves(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_leaves(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_map(self):
|
def test_tree_map(self):
|
||||||
func = lambda x: x * 2
|
func = lambda x: x * 2
|
||||||
obj = [1, 2, (3, 4)]
|
obj = [1, 2, (3, 4)]
|
||||||
@ -1175,6 +1199,15 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_map(func, obj),
|
tree_util.tree_map(func, obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_map_is_leaf(self):
|
||||||
|
func = lambda x: x * 2
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.map(func, obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_map(func, obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_reduce(self):
|
def test_tree_reduce(self):
|
||||||
func = lambda a, b: a + b
|
func = lambda a, b: a + b
|
||||||
obj = [1, 2, (3, 4)]
|
obj = [1, 2, (3, 4)]
|
||||||
@ -1183,6 +1216,15 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_reduce(func, obj),
|
tree_util.tree_reduce(func, obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_reduce_is_leaf(self):
|
||||||
|
func = lambda a, b: a + b
|
||||||
|
obj = [(1, 2), (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.reduce(func, obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_reduce(func, obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_structure(self):
|
def test_tree_structure(self):
|
||||||
obj = [1, 2, (3, 4)]
|
obj = [1, 2, (3, 4)]
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -1190,6 +1232,14 @@ class TreeAliasTest(jtu.JaxTestCase):
|
|||||||
tree_util.tree_structure(obj),
|
tree_util.tree_structure(obj),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tree_structure_is_leaf(self):
|
||||||
|
obj = [1, 2, (3, 4)]
|
||||||
|
is_leaf = lambda x: isinstance(x, tuple)
|
||||||
|
self.assertEqual(
|
||||||
|
jax.tree.structure(obj, is_leaf=is_leaf),
|
||||||
|
tree_util.tree_structure(obj, is_leaf=is_leaf),
|
||||||
|
)
|
||||||
|
|
||||||
def test_tree_transpose(self):
|
def test_tree_transpose(self):
|
||||||
obj = [(1, 2), (3, 4), (5, 6)]
|
obj = [(1, 2), (3, 4), (5, 6)]
|
||||||
outer_treedef = tree_util.tree_structure(['*', '*', '*'])
|
outer_treedef = tree_util.tree_structure(['*', '*', '*'])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user