tree_transpose: optionally infer inner_treedef

This commit is contained in:
Jake VanderPlas 2024-02-15 12:01:21 -08:00
parent 3708336f8f
commit 6ffea0ba1f
3 changed files with 32 additions and 4 deletions

View File

@ -15,6 +15,8 @@ Remember to align the itemized text with the first line of an item within a list
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
* Added {mod}`jax.tree` module, with a more convenient interface for referencing functions
in {mod}`jax.tree_util`.
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.
* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D

View File

@ -316,13 +316,15 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
return treedef.from_iterable_tree(xs)
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
pytree_to_transpose: Any) -> Any:
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
Args:
outer_treedef: PyTreeDef representing the outer tree.
inner_treedef: PyTreeDef representing the inner tree.
If None, then it will be inferred from outer_treedef and the structure of
pytree_to_transpose.
pytree_to_transpose: the pytree to be transposed.
Returns:
@ -335,8 +337,15 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
>>> outer_structure = jax.tree.structure(['*', '*'])
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
([1, 4], [2, 5], [3, 6])
Inferring the inner structure:
>>> jax.tree.transpose(outer_structure, None, tree)
([1, 4], [2, 5], [3, 6])
"""
flat, treedef = tree_flatten(pytree_to_transpose)
if inner_treedef is None:
inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0])
inner_size = inner_treedef.num_leaves
outer_size = outer_treedef.num_leaves
if treedef.num_leaves != (inner_size * outer_size):

View File

@ -519,10 +519,27 @@ class TreeTest(jtu.JaxTestCase):
outer_treedef = tree_util.tree_structure(tree)
if not outer_treedef.num_leaves:
self.skipTest("Skipping empty tree")
inner_treedef = tree_util.tree_structure([1, 1, 1])
nested = tree_util.tree_map(lambda x: [x, x, x], tree)
def make_inner(x):
return [x, x, x]
inner_treedef = tree_util.tree_structure(make_inner(1))
nested = tree_util.tree_map(make_inner, tree)
actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
self.assertEqual(actual, [tree, tree, tree])
self.assertEqual(actual, make_inner(tree))
@parameterized.parameters(*TREES)
def testTransposeInferInnerTreedef(self, tree):
if isinstance(tree, FlatCache):
# The tree_map construction below fails for FlatCache, because
# the cached metadata becomes out of sync.
self.skipTest("Test does not work properly for FlatCache.")
outer_treedef = tree_util.tree_structure(tree)
if not outer_treedef.num_leaves:
self.skipTest("Skipping empty tree")
def make_inner(x):
return [x, {'a': x}, (x,)]
nested = tree_util.tree_map(make_inner, tree)
actual = tree_util.tree_transpose(outer_treedef, None, nested)
self.assertEqual(actual, make_inner(tree))
def testTransposeMismatchOuter(self):
tree = {"a": [1, 2], "b": [3, 4]}