# Copyright 2019 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import functools import pickle import re from absl.testing import absltest from absl.testing import parameterized import jax from jax import tree_util from jax import flatten_util from jax._src import test_util as jtu from jax._src.tree_util import prefix_errors, flatten_one_level import jax.numpy as jnp def _dummy_func(*args, **kwargs): return ATuple = collections.namedtuple("ATuple", ("foo", "bar")) class ANamedTupleSubclass(ATuple): pass ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar")) tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar), lambda bar, foo: ATuple2(foo[0], bar)) class AnObject: def __init__(self, x, y, z): self.x = x self.y = y self.z = z def __eq__(self, other): return self.x == other.x and self.y == other.y and self.z == other.z def __hash__(self): return hash((self.x, self.y, self.z)) def __repr__(self): return f"AnObject({self.x},{self.y},{self.z})" tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z), lambda z, xy: AnObject(xy[0], xy[1], z)) class AnObject2(AnObject): pass tree_util.register_pytree_with_keys( AnObject2, lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys lambda z, xy: AnObject2(xy[0], xy[1], z), # unflatten (no key involved) ) @tree_util.register_pytree_node_class class Special: def __init__(self, x, y): self.x = x self.y = y def __repr__(self): return f"Special(x={self.x}, y={self.y})" def tree_flatten(self): return ((self.x, self.y), None) @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children) def __eq__(self, other): return type(self) is type(other) and (self.x, self.y) == (other.x, other.y) @tree_util.register_pytree_with_keys_class class SpecialWithKeys(Special): def tree_flatten_with_keys(self): return (((tree_util.GetAttrKey('x'), self.x), (tree_util.GetAttrKey('y'), self.y)), None) @tree_util.register_pytree_node_class class FlatCache: def __init__(self, structured, *, leaves=None, treedef=None): if treedef is None: leaves, treedef = tree_util.tree_flatten(structured) self._structured = structured self.treedef = treedef self.leaves = leaves def __hash__(self): return hash(self.structured) def __eq__(self, other): return self.structured == other.structured def __repr__(self): return f"FlatCache({self.structured!r})" @property def structured(self): if self._structured is None: self._structured = tree_util.tree_unflatten(self.treedef, self.leaves) return self._structured def tree_flatten(self): return self.leaves, self.treedef @classmethod def tree_unflatten(cls, meta, data): if not tree_util.all_leaves(data): data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data)) return FlatCache(None, leaves=data, treedef=meta) TREES = ( (None,), ((None,),), ((),), (([()]),), ((1, 2),), (((1, "foo"), ["bar", (3, None, 7)]),), ([3],), ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],), ([AnObject(3, None, [4, "foo"])],), ([AnObject2(3, None, [4, "foo"])],), (Special(2, 3.),), ({"a": 1, "b": 2},), (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), (collections.defaultdict(dict, [("foo", 34), ("baz", 101), ("something", -42)]),), (ANamedTupleSubclass(foo="hello", bar=3.5),), (FlatCache(None),), (FlatCache(1),), (FlatCache({"a": [1, 2]}),), ) TREE_STRINGS = ( "PyTreeDef(None)", "PyTreeDef((None,))", "PyTreeDef(())", "PyTreeDef([()])", "PyTreeDef((*, *))", "PyTreeDef(((*, *), [*, (*, None, *)]))", "PyTreeDef([*])", ("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, " "CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"), "PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])", "PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])", "PyTreeDef(CustomNode(Special[None], [*, *]))", "PyTreeDef({'a': *, 'b': *})", ) # pytest expects "tree_util_test.ATuple" STRS = [] for tree_str in TREE_STRINGS: tree_str = re.escape(tree_str) tree_str = tree_str.replace("__main__", ".*") STRS.append(tree_str) TREE_STRINGS = STRS LEAVES = ( ("foo",), (0.1,), (1,), (object(),), ) # All except those decorated by register_pytree_node_class TREES_WITH_KEYPATH = ( (None,), ((None,),), ((),), (([()]),), ((1, 0),), (((1, "foo"), ["bar", (3, None, 7)]),), ([3],), ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],), ([AnObject2(3, None, [4, "foo"])],), (SpecialWithKeys(2, 3.),), ({"a": 1, "b": 0},), (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), (collections.defaultdict(dict, [("foo", 34), ("baz", 101), ("something", -42)]),), (ANamedTupleSubclass(foo="hello", bar=3.5),), ) class TreeTest(jtu.JaxTestCase): @parameterized.parameters(*(TREES + LEAVES)) def testRoundtrip(self, inputs): xs, tree = tree_util.tree_flatten(inputs) actual = tree_util.tree_unflatten(tree, xs) self.assertEqual(actual, inputs) @parameterized.parameters(*(TREES + LEAVES)) def testRoundtripWithFlattenUpTo(self, inputs): _, tree = tree_util.tree_flatten(inputs) xs = tree.flatten_up_to(inputs) actual = tree_util.tree_unflatten(tree, xs) self.assertEqual(actual, inputs) @parameterized.parameters( (tree_util.Partial(_dummy_func),), (tree_util.Partial(_dummy_func, 1, 2),), (tree_util.Partial(_dummy_func, x="a"),), (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),), ) def testRoundtripPartial(self, inputs): xs, tree = tree_util.tree_flatten(inputs) actual = tree_util.tree_unflatten(tree, xs) # functools.partial does not support equality comparisons: # https://stackoverflow.com/a/32786109/809705 self.assertEqual(actual.func, inputs.func) self.assertEqual(actual.args, inputs.args) self.assertEqual(actual.keywords, inputs.keywords) def testPartialDoesNotMergeWithOtherPartials(self): def f(a, b, c): pass g = functools.partial(f, 2) h = tree_util.Partial(g, 3) self.assertEqual(h.args, (3,)) def testPartialFuncAttributeHasStableHash(self): # https://github.com/google/jax/issues/9429 fun = functools.partial(print, 1) p1 = tree_util.Partial(fun, 2) p2 = tree_util.Partial(fun, 2) self.assertEqual(fun, p1.func) self.assertEqual(p1.func, fun) self.assertEqual(p1.func, p2.func) self.assertEqual(hash(p1.func), hash(p2.func)) def testChildren(self): _, tree = tree_util.tree_flatten(((1, 2, 3), (4,))) _, c0 = tree_util.tree_flatten((0, 0, 0)) _, c1 = tree_util.tree_flatten((7,)) self.assertEqual([c0, c1], tree.children()) def testTreedefTupleFromChildren(self): # https://github.com/google/jax/issues/7377 tree = ((1, 2, (3, 4)), (5,)) leaves, treedef1 = tree_util.tree_flatten(tree) treedef2 = tree_util.treedef_tuple(treedef1.children()) self.assertEqual(treedef1.num_leaves, len(leaves)) self.assertEqual(treedef1.num_leaves, treedef2.num_leaves) self.assertEqual(treedef1.num_nodes, treedef2.num_nodes) def testTreedefTupleComparesEqual(self): # https://github.com/google/jax/issues/9066 self.assertEqual(tree_util.tree_structure((3,)), tree_util.treedef_tuple((tree_util.tree_structure(3),))) def testFlattenOrder(self): flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9]) flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9]) flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9]) self.assertEqual(flat1, list(range(10))) self.assertEqual(flat2, list(range(10))) self.assertEqual(flat3, list(range(10))) def testFlattenUpTo(self): _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)]) out = tree.flatten_up_to([({ "foo": 7 }, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]) self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None]) def testTreeMap(self): x = ((1, 2), (3, 4, 5)) y = (([3], None), ({"foo": "bar"}, 7, [5, 6])) out = tree_util.tree_map(lambda *xs: tuple(xs), x, y) self.assertEqual(out, (((1, [3]), (2, None)), ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))) def testTreeMapWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], None), ({"foo": "bar"}, 7, [5, 6])) out = tree_util.tree_map(lambda *xs: tuple(xs), x, y, is_leaf=lambda n: isinstance(n, list)) self.assertEqual(out, (((1, [3]), (2, None)), (([3, 4, 5], ({"foo": "bar"}, 7, [5, 6]))))) @parameterized.parameters( tree_util.tree_leaves, lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0]) def testFlattenIsLeaf(self, leaf_fn): x = [(1, 2), (3, 4), (5, 6)] leaves = leaf_fn(x, is_leaf=lambda t: False) self.assertEqual(leaves, [1, 2, 3, 4, 5, 6]) leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, tuple)) self.assertEqual(leaves, x) leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, list)) self.assertEqual(leaves, [x]) leaves = leaf_fn(x, is_leaf=lambda t: True) self.assertEqual(leaves, [x]) y = [[[(1,)], [[(2,)], {"a": (3,)}]]] leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple)) self.assertEqual(leaves, [(1,), (2,), (3,)]) @parameterized.parameters( tree_util.tree_structure, lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1]) def testStructureIsLeaf(self, structure_fn): x = [(1, 2), (3, 4), (5, 6)] treedef = structure_fn(x, is_leaf=lambda t: False) self.assertEqual(treedef.num_leaves, 6) treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, tuple)) self.assertEqual(treedef.num_leaves, 3) treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, list)) self.assertEqual(treedef.num_leaves, 1) treedef = structure_fn(x, is_leaf=lambda t: True) self.assertEqual(treedef.num_leaves, 1) y = [[[(1,)], [[(2,)], {"a": (3,)}]]] treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple)) self.assertEqual(treedef.num_leaves, 3) @parameterized.parameters(*TREES) def testRoundtripIsLeaf(self, tree): xs, treedef = tree_util.tree_flatten( tree, is_leaf=lambda t: isinstance(t, tuple)) recon_tree = tree_util.tree_unflatten(treedef, xs) self.assertEqual(recon_tree, tree) @parameterized.parameters(*TREES) def testAllLeavesWithTrees(self, tree): leaves = tree_util.tree_leaves(tree) self.assertTrue(tree_util.all_leaves(leaves)) self.assertFalse(tree_util.all_leaves([tree])) @parameterized.parameters(*LEAVES) def testAllLeavesWithLeaves(self, leaf): self.assertTrue(tree_util.all_leaves([leaf])) @parameterized.parameters(*TREES) def testCompose(self, tree): treedef = tree_util.tree_structure(tree) inner_treedef = tree_util.tree_structure(["*", "*", "*"]) composed_treedef = treedef.compose(inner_treedef) expected_leaves = treedef.num_leaves * inner_treedef.num_leaves self.assertEqual(composed_treedef.num_leaves, expected_leaves) expected_nodes = ((treedef.num_nodes - treedef.num_leaves) + (inner_treedef.num_nodes * treedef.num_leaves)) self.assertEqual(composed_treedef.num_nodes, expected_nodes) leaves = [1] * expected_leaves composed = tree_util.tree_unflatten(composed_treedef, leaves) self.assertEqual(leaves, tree_util.tree_leaves(composed)) @parameterized.parameters(*TREES) def testTranspose(self, tree): outer_treedef = tree_util.tree_structure(tree) if not outer_treedef.num_leaves: self.skipTest("Skipping empty tree") inner_treedef = tree_util.tree_structure([1, 1, 1]) nested = tree_util.tree_map(lambda x: [x, x, x], tree) actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested) self.assertEqual(actual, [tree, tree, tree]) def testTransposeMismatchOuter(self): tree = {"a": [1, 2], "b": [3, 4]} outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3}) inner_treedef = tree_util.tree_structure([1, 2]) with self.assertRaisesRegex(TypeError, "Mismatch"): tree_util.tree_transpose(outer_treedef, inner_treedef, tree) def testTransposeMismatchInner(self): tree = {"a": [1, 2], "b": [3, 4]} outer_treedef = tree_util.tree_structure({"a": 1, "b": 2}) inner_treedef = tree_util.tree_structure([1, 2, 3]) with self.assertRaisesRegex(TypeError, "Mismatch"): tree_util.tree_transpose(outer_treedef, inner_treedef, tree) def testTransposeWithCustomObject(self): outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2})) inner_treedef = tree_util.tree_structure([1, 2]) expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})] actual = tree_util.tree_transpose(outer_treedef, inner_treedef, FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" treedef = tree_util.tree_structure(tree) self.assertRegex(str(treedef), correct_string) def testTreeDefWithEmptyDictStringRepresentation(self): self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})") @parameterized.parameters(*TREES) def testPickleRoundTrip(self, tree): treedef = tree_util.tree_structure(tree) treedef_restored = pickle.loads(pickle.dumps(treedef)) self.assertEqual(treedef, treedef_restored) def testDictKeysSortable(self): d = {"a": 1, 2: "b"} with self.assertRaisesRegex(TypeError, "'<' not supported"): _, _ = tree_util.tree_flatten(d) def testFlattenDictKeyOrder(self): d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}} leaves, treedef = tree_util.tree_flatten(d) self.assertEqual(leaves, [1, 2, 1, 2]) self.assertEqual( str(treedef), "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})" ) restored_d = tree_util.tree_unflatten(treedef, leaves) self.assertEqual(list(restored_d.keys()), ["a", "b", "c"]) def testWalk(self): d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}} leaves, treedef = tree_util.tree_flatten(d) nodes_visited = [] node_data_visited = [] leaves_visited = [] def f_node(node, node_data): nodes_visited.append(node) node_data_visited.append(node_data) def f_leaf(leaf): leaves_visited.append(leaf) treedef.walk(f_node, f_leaf, leaves) self.assertEqual(leaves_visited, [1, 2, 1, 2]) self.assertEqual(nodes_visited, [(None, None), (None, None, None)]) self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]]) @parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES)) def testRoundtripWithPath(self, inputs): key_leaves, treedef = tree_util.tree_flatten_with_path(inputs) actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves]) self.assertEqual(actual, inputs) def testTreeMapWithPath(self): tree = [{i: i for i in range(10)}] all_zeros = tree_util.tree_map_with_path( lambda kp, val: val - kp[1].key + kp[0].idx, tree ) self.assertEqual(all_zeros, [{i: 0 for i in range(10)}]) def testTreeMapWithPathMultipleTrees(self): tree1 = [AnObject2(x=12, y={'cin': [1, 4, 10], 'bar': None}, z='constantdef'), 5] tree2 = [AnObject2(x=2, y={'cin': [2, 2, 2], 'bar': None}, z='constantdef'), 2] from_two_trees = tree_util.tree_map_with_path( lambda kp, a, b: a + b, tree1, tree2 ) from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1) self.assertEqual(from_two_trees, from_one_tree) def testKeyStr(self): tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)] flattened, _ = tree_util.tree_flatten_with_path(tree1) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, [ "[0].foo: 12", "[0].bar['cin'][0]: 1", "[0].bar['cin'][1]: 4", "[0].bar['cin'][2]: 10", "[1]: [0 1 2 3 4]", ], ) def testTreeMapWithPathWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], jnp.array((0))), ([0], 7, [5, 6])) out = tree_util.tree_map_with_path( lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y, is_leaf=lambda n: isinstance(n, list)) self.assertEqual(out, (((0, 1, [3]), (0, 2, jnp.array((0)))), (1, [3, 4, 5], ([0], 7, [5, 6])))) def testFlattenWithPathWithIsLeafArgument(self): def is_empty(x): try: children, _ = flatten_one_level(x) except ValueError: return True # Cannot flatten x; means it must be a leaf return len(children) == 0 EmptyTuple = collections.namedtuple("EmptyTuple", ()) tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, [ "['a']: 1", "['obj']x: EmptyTuple()", "['obj']y: 0", "['sub'][0]: [1 2]", "['sub'][1].foo: ()", "['sub'][1].bar[0]: None", ], ) def testFlattenOneLevel(self): EmptyTuple = collections.namedtuple("EmptyTuple", ()) tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} self.assertEqual(flatten_one_level(tree1["sub"])[0], tree1["sub"]) self.assertEqual(flatten_one_level(tree1["sub"][1])[0], [(), [None]]) self.assertEqual(flatten_one_level(tree1["obj"])[0], [EmptyTuple(), 0]) with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): flatten_one_level(1) with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): flatten_one_level(jnp.array((1, 2))) def testOptionalFlatten(self): @tree_util.register_pytree_with_keys_class class FooClass: def __init__(self, x, y): self.x = x self.y = y def tree_flatten(self): return ((self.x, self.y), 'treedef') def tree_flatten_with_keys(self): return (((tree_util.GetAttrKey('x'), self.x), (tree_util.GetAttrKey('x'), self.y)), 'treedef') @classmethod def tree_unflatten(cls, _, children): return cls(*children) tree = FooClass(x=1, y=2) self.assertEqual( str(tree_util.tree_flatten(tree)[1]), "PyTreeDef(CustomNode(FooClass[treedef], [*, *]))", ) self.assertEqual( str(tree_util.tree_flatten_with_path(tree)[1]), "PyTreeDef(CustomNode(FooClass[treedef], [*, *]))", ) self.assertEqual(tree_util.tree_flatten(tree)[0], [l for _, l in tree_util.tree_flatten_with_path(tree)[0]]) def testPyTreeWithoutKeysIsntTreatedAsLeaf(self): leaves, _ = tree_util.tree_flatten_with_path(Special([1, 2], [3, 4])) self.assertLen(leaves, 4) def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self): leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi')) self.assertLen(leaves, 1) class RavelUtilTest(jtu.JaxTestCase): def testFloats(self): tree = [ jnp.array([3.0], jnp.float32), jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32), ] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.float32) tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testInts(self): tree = [jnp.array([3], jnp.int32), jnp.array([[1, 2], [3, 4]], jnp.int32)] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.int32) tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedFloatInt(self): tree = [jnp.array([3], jnp.int32), jnp.array([[1., 2.], [3., 4.]], jnp.float32)] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.int32)) tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedIntBool(self): tree = [jnp.array([0], jnp.bool_), jnp.array([[1, 2], [3, 4]], jnp.int32)] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.promote_types(jnp.bool_, jnp.int32)) tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedFloatComplex(self): tree = [jnp.array([1.], jnp.float32), jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.complex64)) tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testEmpty(self): tree = [] raveled, unravel = flatten_util.ravel_pytree(tree) self.assertEqual(raveled.dtype, jnp.float32) # convention tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testDtypePolymorphicUnravel(self): # https://github.com/google/jax/issues/7809 x = jnp.arange(10, dtype=jnp.float32) x_flat, unravel = flatten_util.ravel_pytree(x) y = x_flat < 5.3 x_ = unravel(y) self.assertEqual(x_.dtype, y.dtype) @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testDtypeMonomorphicUnravel(self): # https://github.com/google/jax/issues/7809 x1 = jnp.arange(10, dtype=jnp.float32) x2 = jnp.arange(10, dtype=jnp.int32) x_flat, unravel = flatten_util.ravel_pytree((x1, x2)) y = x_flat < 5.3 with self.assertRaisesRegex(TypeError, 'but expected dtype'): _ = unravel(y) def test_no_recompile(self): x1 = jnp.array([1, 2]) x2 = jnp.array([3, 4]) x_flat1, unravel1 = flatten_util.ravel_pytree((x1, x2)) x_flat2, unravel2 = flatten_util.ravel_pytree((x1, x2)) num_traces = 0 def run(flat, unravel): nonlocal num_traces num_traces += 1 flat = flat + 1 return unravel(flat) run = jax.jit(run, static_argnums=1) run(x_flat1, unravel1) run(x_flat2, unravel2) self.assertEqual(num_traces, 1) class TreePrefixErrorsTest(jtu.JaxTestCase): def test_different_types(self): e, = prefix_errors((1, 2), [1, 2]) expected = ("pytree structure error: different types at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_types_nested(self): e, = prefix_errors(((1,), (2,)), ([3], (4,))) expected = ("pytree structure error: different types at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_types_multiple(self): e1, e2 = prefix_errors(((1,), (2,)), ([3], [4])) expected = ("pytree structure error: different types at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e1('in_axes') expected = ("pytree structure error: different types at key path\n" r" in_axes\[1\]") with self.assertRaisesRegex(ValueError, expected): raise e2('in_axes') def test_different_num_children_tuple(self): e, = prefix_errors((1,), (2, 3)) expected = ("pytree structure error: different lengths of tuple " "at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_num_children_list(self): e, = prefix_errors([1], [2, 3]) expected = ("pytree structure error: different lengths of list " "at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_num_children_generic(self): e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3}) expected = ("pytree structure error: different numbers of pytree children " "at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_num_children_nested(self): e, = prefix_errors([[1]], [[2, 3]]) expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_num_children_multiple(self): e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]]) expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e1('in_axes') expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[1\]") with self.assertRaisesRegex(ValueError, expected): raise e2('in_axes') def test_different_num_children_print_key_diff(self): e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3}) expected = ("so the symmetric difference on key sets is\n" " b") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_metadata(self): e, = prefix_errors({1: 2}, {3: 4}) expected = ("pytree structure error: different pytree metadata " "at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_metadata_nested(self): e, = prefix_errors([{1: 2}], [{3: 4}]) expected = ("pytree structure error: different pytree metadata " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_different_metadata_multiple(self): e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}]) expected = ("pytree structure error: different pytree metadata " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e1('in_axes') expected = ("pytree structure error: different pytree metadata " "at key path\n" r" in_axes\[1\]") with self.assertRaisesRegex(ValueError, expected): raise e2('in_axes') def test_fallback_keypath(self): e, = prefix_errors(Special(1, [2]), Special(3, 4)) expected = ("pytree structure error: different types at key path\n" r" in_axes\[\]") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') def test_no_errors(self): () = prefix_errors((1, 2), ((11, 12, 13), 2)) def test_different_structure_no_children(self): e, = prefix_errors({}, {'a': []}) expected = ("pytree structure error: different numbers of pytree children " "at key path\n" " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())