mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
tree_transpose: optionally infer inner_treedef
This commit is contained in:
parent
3708336f8f
commit
6ffea0ba1f
@ -15,6 +15,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
|
||||
* Added {mod}`jax.tree` module, with a more convenient interface for referencing functions
|
||||
in {mod}`jax.tree_util`.
|
||||
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
|
||||
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.
|
||||
|
||||
* Deprecations & Removals
|
||||
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
|
||||
|
@ -316,13 +316,15 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
|
||||
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
|
||||
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
|
||||
pytree_to_transpose: Any) -> Any:
|
||||
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
|
||||
|
||||
Args:
|
||||
outer_treedef: PyTreeDef representing the outer tree.
|
||||
inner_treedef: PyTreeDef representing the inner tree.
|
||||
If None, then it will be inferred from outer_treedef and the structure of
|
||||
pytree_to_transpose.
|
||||
pytree_to_transpose: the pytree to be transposed.
|
||||
|
||||
Returns:
|
||||
@ -335,8 +337,15 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
|
||||
>>> outer_structure = jax.tree.structure(['*', '*'])
|
||||
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
|
||||
([1, 4], [2, 5], [3, 6])
|
||||
|
||||
Inferring the inner structure:
|
||||
|
||||
>>> jax.tree.transpose(outer_structure, None, tree)
|
||||
([1, 4], [2, 5], [3, 6])
|
||||
"""
|
||||
flat, treedef = tree_flatten(pytree_to_transpose)
|
||||
if inner_treedef is None:
|
||||
inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0])
|
||||
inner_size = inner_treedef.num_leaves
|
||||
outer_size = outer_treedef.num_leaves
|
||||
if treedef.num_leaves != (inner_size * outer_size):
|
||||
|
@ -519,10 +519,27 @@ class TreeTest(jtu.JaxTestCase):
|
||||
outer_treedef = tree_util.tree_structure(tree)
|
||||
if not outer_treedef.num_leaves:
|
||||
self.skipTest("Skipping empty tree")
|
||||
inner_treedef = tree_util.tree_structure([1, 1, 1])
|
||||
nested = tree_util.tree_map(lambda x: [x, x, x], tree)
|
||||
def make_inner(x):
|
||||
return [x, x, x]
|
||||
inner_treedef = tree_util.tree_structure(make_inner(1))
|
||||
nested = tree_util.tree_map(make_inner, tree)
|
||||
actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
|
||||
self.assertEqual(actual, [tree, tree, tree])
|
||||
self.assertEqual(actual, make_inner(tree))
|
||||
|
||||
@parameterized.parameters(*TREES)
|
||||
def testTransposeInferInnerTreedef(self, tree):
|
||||
if isinstance(tree, FlatCache):
|
||||
# The tree_map construction below fails for FlatCache, because
|
||||
# the cached metadata becomes out of sync.
|
||||
self.skipTest("Test does not work properly for FlatCache.")
|
||||
outer_treedef = tree_util.tree_structure(tree)
|
||||
if not outer_treedef.num_leaves:
|
||||
self.skipTest("Skipping empty tree")
|
||||
def make_inner(x):
|
||||
return [x, {'a': x}, (x,)]
|
||||
nested = tree_util.tree_map(make_inner, tree)
|
||||
actual = tree_util.tree_transpose(outer_treedef, None, nested)
|
||||
self.assertEqual(actual, make_inner(tree))
|
||||
|
||||
def testTransposeMismatchOuter(self):
|
||||
tree = {"a": [1, 2], "b": [3, 4]}
|
||||
|
Loading…
x
Reference in New Issue
Block a user