rocm_jax/tests/tree_util_test.py
Peter Hawkins 319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00

817 lines
29 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import pickle
import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import tree_util
from jax import flatten_util
from jax._src import test_util as jtu
from jax._src.tree_util import prefix_errors, flatten_one_level
import jax.numpy as jnp
def _dummy_func(*args, **kwargs):
return
ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
class ANamedTupleSubclass(ATuple):
pass
ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
lambda bar, foo: ATuple2(foo[0], bar))
class AnObject:
def __init__(self, x, y, z):
self.x = x
self.y = y
self.z = z
def __eq__(self, other):
return self.x == other.x and self.y == other.y and self.z == other.z
def __hash__(self):
return hash((self.x, self.y, self.z))
def __repr__(self):
return f"AnObject({self.x},{self.y},{self.z})"
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
lambda z, xy: AnObject(xy[0], xy[1], z))
class AnObject2(AnObject): pass
tree_util.register_pytree_with_keys(
AnObject2,
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
lambda z, xy: AnObject2(xy[0], xy[1], z), # unflatten (no key involved)
)
@tree_util.register_pytree_node_class
class Special:
def __init__(self, x, y):
self.x = x
self.y = y
def __repr__(self):
return f"Special(x={self.x}, y={self.y})"
def tree_flatten(self):
return ((self.x, self.y), None)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __eq__(self, other):
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
@tree_util.register_pytree_with_keys_class
class SpecialWithKeys(Special):
def tree_flatten_with_keys(self):
return (((tree_util.GetAttrKey('x'), self.x),
(tree_util.GetAttrKey('y'), self.y)), None)
@tree_util.register_pytree_node_class
class FlatCache:
def __init__(self, structured, *, leaves=None, treedef=None):
if treedef is None:
leaves, treedef = tree_util.tree_flatten(structured)
self._structured = structured
self.treedef = treedef
self.leaves = leaves
def __hash__(self):
return hash(self.structured)
def __eq__(self, other):
return self.structured == other.structured
def __repr__(self):
return f"FlatCache({self.structured!r})"
@property
def structured(self):
if self._structured is None:
self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
return self._structured
def tree_flatten(self):
return self.leaves, self.treedef
@classmethod
def tree_unflatten(cls, meta, data):
if not tree_util.all_leaves(data):
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
return FlatCache(None, leaves=data, treedef=meta)
TREES = (
(None,),
((None,),),
((),),
(([()]),),
((1, 2),),
(((1, "foo"), ["bar", (3, None, 7)]),),
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject(3, None, [4, "foo"])],),
([AnObject2(3, None, [4, "foo"])],),
(Special(2, 3.),),
({"a": 1, "b": 2},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
(ANamedTupleSubclass(foo="hello", bar=3.5),),
(FlatCache(None),),
(FlatCache(1),),
(FlatCache({"a": [1, 2]}),),
)
TREE_STRINGS = (
"PyTreeDef(None)",
"PyTreeDef((None,))",
"PyTreeDef(())",
"PyTreeDef([()])",
"PyTreeDef((*, *))",
"PyTreeDef(((*, *), [*, (*, None, *)]))",
"PyTreeDef([*])",
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(Special[None], [*, *]))",
"PyTreeDef({'a': *, 'b': *})",
)
# pytest expects "tree_util_test.ATuple"
STRS = []
for tree_str in TREE_STRINGS:
tree_str = re.escape(tree_str)
tree_str = tree_str.replace("__main__", ".*")
STRS.append(tree_str)
TREE_STRINGS = STRS
LEAVES = (
("foo",),
(0.1,),
(1,),
(object(),),
)
# All except those decorated by register_pytree_node_class
TREES_WITH_KEYPATH = (
(None,),
((None,),),
((),),
(([()]),),
((1, 0),),
(((1, "foo"), ["bar", (3, None, 7)]),),
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject2(3, None, [4, "foo"])],),
(SpecialWithKeys(2, 3.),),
({"a": 1, "b": 0},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
(ANamedTupleSubclass(foo="hello", bar=3.5),),
)
class TreeTest(jtu.JaxTestCase):
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtrip(self, inputs):
xs, tree = tree_util.tree_flatten(inputs)
actual = tree_util.tree_unflatten(tree, xs)
self.assertEqual(actual, inputs)
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtripWithFlattenUpTo(self, inputs):
_, tree = tree_util.tree_flatten(inputs)
xs = tree.flatten_up_to(inputs)
actual = tree_util.tree_unflatten(tree, xs)
self.assertEqual(actual, inputs)
@parameterized.parameters(
(tree_util.Partial(_dummy_func),),
(tree_util.Partial(_dummy_func, 1, 2),),
(tree_util.Partial(_dummy_func, x="a"),),
(tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
)
def testRoundtripPartial(self, inputs):
xs, tree = tree_util.tree_flatten(inputs)
actual = tree_util.tree_unflatten(tree, xs)
# functools.partial does not support equality comparisons:
# https://stackoverflow.com/a/32786109/809705
self.assertEqual(actual.func, inputs.func)
self.assertEqual(actual.args, inputs.args)
self.assertEqual(actual.keywords, inputs.keywords)
def testPartialDoesNotMergeWithOtherPartials(self):
def f(a, b, c): pass
g = functools.partial(f, 2)
h = tree_util.Partial(g, 3)
self.assertEqual(h.args, (3,))
def testPartialFuncAttributeHasStableHash(self):
# https://github.com/google/jax/issues/9429
fun = functools.partial(print, 1)
p1 = tree_util.Partial(fun, 2)
p2 = tree_util.Partial(fun, 2)
self.assertEqual(fun, p1.func)
self.assertEqual(p1.func, fun)
self.assertEqual(p1.func, p2.func)
self.assertEqual(hash(p1.func), hash(p2.func))
def testChildren(self):
_, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
_, c0 = tree_util.tree_flatten((0, 0, 0))
_, c1 = tree_util.tree_flatten((7,))
self.assertEqual([c0, c1], tree.children())
def testTreedefTupleFromChildren(self):
# https://github.com/google/jax/issues/7377
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
treedef2 = tree_util.treedef_tuple(treedef1.children())
self.assertEqual(treedef1.num_leaves, len(leaves))
self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
def testTreedefTupleComparesEqual(self):
# https://github.com/google/jax/issues/9066
self.assertEqual(tree_util.tree_structure((3,)),
tree_util.treedef_tuple((tree_util.tree_structure(3),)))
def testFlattenOrder(self):
flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9])
self.assertEqual(flat1, list(range(10)))
self.assertEqual(flat2, list(range(10)))
self.assertEqual(flat3, list(range(10)))
def testFlattenUpTo(self):
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
out = tree.flatten_up_to([({
"foo": 7
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
def testTreeMap(self):
x = ((1, 2), (3, 4, 5))
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y)
self.assertEqual(out, (((1, [3]), (2, None)),
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
def testTreeMapWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y,
is_leaf=lambda n: isinstance(n, list))
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
def testTreeReduceWithIsLeafArgument(self):
out = tree_util.tree_reduce(lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]],
is_leaf=lambda l: isinstance(l, tuple))
self.assertEqual(out, (1, 2, 3, 4, 5, 6))
@parameterized.parameters(
tree_util.tree_leaves,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
def testFlattenIsLeaf(self, leaf_fn):
x = [(1, 2), (3, 4), (5, 6)]
leaves = leaf_fn(x, is_leaf=lambda t: False)
self.assertEqual(leaves, [1, 2, 3, 4, 5, 6])
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(leaves, x)
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, list))
self.assertEqual(leaves, [x])
leaves = leaf_fn(x, is_leaf=lambda t: True)
self.assertEqual(leaves, [x])
y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(leaves, [(1,), (2,), (3,)])
@parameterized.parameters(
tree_util.tree_structure,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
def testStructureIsLeaf(self, structure_fn):
x = [(1, 2), (3, 4), (5, 6)]
treedef = structure_fn(x, is_leaf=lambda t: False)
self.assertEqual(treedef.num_leaves, 6)
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(treedef.num_leaves, 3)
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, list))
self.assertEqual(treedef.num_leaves, 1)
treedef = structure_fn(x, is_leaf=lambda t: True)
self.assertEqual(treedef.num_leaves, 1)
y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(treedef.num_leaves, 3)
@parameterized.parameters(*TREES)
def testRoundtripIsLeaf(self, tree):
xs, treedef = tree_util.tree_flatten(
tree, is_leaf=lambda t: isinstance(t, tuple))
recon_tree = tree_util.tree_unflatten(treedef, xs)
self.assertEqual(recon_tree, tree)
@parameterized.parameters(*TREES)
def testAllLeavesWithTrees(self, tree):
leaves = tree_util.tree_leaves(tree)
self.assertTrue(tree_util.all_leaves(leaves))
self.assertFalse(tree_util.all_leaves([tree]))
@parameterized.parameters(*LEAVES)
def testAllLeavesWithLeaves(self, leaf):
self.assertTrue(tree_util.all_leaves([leaf]))
@parameterized.parameters(*TREES)
def testCompose(self, tree):
treedef = tree_util.tree_structure(tree)
inner_treedef = tree_util.tree_structure(["*", "*", "*"])
composed_treedef = treedef.compose(inner_treedef)
expected_leaves = treedef.num_leaves * inner_treedef.num_leaves
self.assertEqual(composed_treedef.num_leaves, expected_leaves)
expected_nodes = ((treedef.num_nodes - treedef.num_leaves) +
(inner_treedef.num_nodes * treedef.num_leaves))
self.assertEqual(composed_treedef.num_nodes, expected_nodes)
leaves = [1] * expected_leaves
composed = tree_util.tree_unflatten(composed_treedef, leaves)
self.assertEqual(leaves, tree_util.tree_leaves(composed))
@parameterized.parameters(*TREES)
def testTranspose(self, tree):
outer_treedef = tree_util.tree_structure(tree)
if not outer_treedef.num_leaves:
self.skipTest("Skipping empty tree")
inner_treedef = tree_util.tree_structure([1, 1, 1])
nested = tree_util.tree_map(lambda x: [x, x, x], tree)
actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
self.assertEqual(actual, [tree, tree, tree])
def testTransposeMismatchOuter(self):
tree = {"a": [1, 2], "b": [3, 4]}
outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
inner_treedef = tree_util.tree_structure([1, 2])
with self.assertRaisesRegex(TypeError, "Mismatch"):
tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
def testTransposeMismatchInner(self):
tree = {"a": [1, 2], "b": [3, 4]}
outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
inner_treedef = tree_util.tree_structure([1, 2, 3])
with self.assertRaisesRegex(TypeError, "Mismatch"):
tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
def testTransposeWithCustomObject(self):
outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
inner_treedef = tree_util.tree_structure([1, 2])
expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
actual = tree_util.tree_transpose(outer_treedef, inner_treedef,
FlatCache({"a": [3, 4], "b": [5, 6]}))
self.assertEqual(expected, actual)
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
def testStringRepresentation(self, tree, correct_string):
"""Checks that the string representation of a tree works."""
treedef = tree_util.tree_structure(tree)
self.assertRegex(str(treedef), correct_string)
def testTreeDefWithEmptyDictStringRepresentation(self):
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
@parameterized.parameters(*TREES)
def testPickleRoundTrip(self, tree):
leaves, treedef = tree_util.tree_flatten(tree)
treedef_restored = pickle.loads(pickle.dumps(treedef))
self.assertEqual(treedef, treedef_restored)
reconstituted = treedef_restored.unflatten(leaves)
self.assertEqual(tree, reconstituted)
def testDictKeysSortable(self):
d = {"a": 1, 2: "b"}
with self.assertRaisesRegex(TypeError, "'<' not supported"):
_, _ = tree_util.tree_flatten(d)
def testFlattenDictKeyOrder(self):
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
leaves, treedef = tree_util.tree_flatten(d)
self.assertEqual(leaves, [1, 2, 1, 2])
self.assertEqual(
str(treedef), "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})"
)
restored_d = tree_util.tree_unflatten(treedef, leaves)
self.assertEqual(list(restored_d.keys()), ["a", "b", "c"])
def testWalk(self):
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
leaves, treedef = tree_util.tree_flatten(d)
nodes_visited = []
node_data_visited = []
leaves_visited = []
def f_node(node, node_data):
nodes_visited.append(node)
node_data_visited.append(node_data)
def f_leaf(leaf):
leaves_visited.append(leaf)
treedef.walk(f_node, f_leaf, leaves)
self.assertEqual(leaves_visited, [1, 2, 1, 2])
self.assertEqual(nodes_visited, [(None, None), (None, None, None)])
self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]])
@parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES))
def testRoundtripWithPath(self, inputs):
key_leaves, treedef = tree_util.tree_flatten_with_path(inputs)
actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves])
self.assertEqual(actual, inputs)
def testTreeMapWithPath(self):
tree = [{i: i for i in range(10)}]
all_zeros = tree_util.tree_map_with_path(
lambda kp, val: val - kp[1].key + kp[0].idx, tree
)
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
def testTreeMapWithPathMultipleTrees(self):
tree1 = [AnObject2(x=12,
y={'cin': [1, 4, 10], 'bar': None},
z='constantdef'),
5]
tree2 = [AnObject2(x=2,
y={'cin': [2, 2, 2], 'bar': None},
z='constantdef'),
2]
from_two_trees = tree_util.tree_map_with_path(
lambda kp, a, b: a + b, tree1, tree2
)
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree)
def testTreeLeavesWithPath(self):
tree = [{i: i for i in range(10)}]
actual = tree_util.tree_leaves_with_path(tree)
expected = [((tree_util.SequenceKey(0), tree_util.DictKey(i)), i)
for i in range(10)]
self.assertEqual(actual, expected)
def testKeyStr(self):
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
flattened, _ = tree_util.tree_flatten_with_path(tree1)
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
self.assertEqual(
strs,
[
"[0].foo: 12",
"[0].bar['cin'][0]: 1",
"[0].bar['cin'][1]: 4",
"[0].bar['cin'][2]: 10",
"[1]: [0 1 2 3 4]",
],
)
def testTreeMapWithPathWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
out = tree_util.tree_map_with_path(
lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y,
is_leaf=lambda n: isinstance(n, list))
self.assertEqual(out, (((0, 1, [3]),
(0, 2, jnp.array(0))),
(1, [3, 4, 5], ([0], 7, [5, 6]))))
def testFlattenWithPathWithIsLeafArgument(self):
def is_empty(x):
try:
children, _ = flatten_one_level(x)
except ValueError:
return True # Cannot flatten x; means it must be a leaf
return len(children) == 0
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
self.assertEqual(
strs,
[
"['a']: 1",
"['obj']x: EmptyTuple()",
"['obj']y: 0",
"['sub'][0]: [1 2]",
"['sub'][1].foo: ()",
"['sub'][1].bar[0]: None",
],
)
def testFlattenOneLevel(self):
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
self.assertEqual(flatten_one_level(tree1["sub"])[0],
tree1["sub"])
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],
[(), [None]])
self.assertEqual(flatten_one_level(tree1["obj"])[0],
[EmptyTuple(), 0])
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(1)
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(jnp.array((1, 2)))
def testOptionalFlatten(self):
@tree_util.register_pytree_with_keys_class
class FooClass:
def __init__(self, x, y):
self.x = x
self.y = y
def tree_flatten(self):
return ((self.x, self.y), 'treedef')
def tree_flatten_with_keys(self):
return (((tree_util.GetAttrKey('x'), self.x),
(tree_util.GetAttrKey('x'), self.y)), 'treedef')
@classmethod
def tree_unflatten(cls, _, children):
return cls(*children)
tree = FooClass(x=1, y=2)
self.assertEqual(
str(tree_util.tree_flatten(tree)[1]),
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
)
self.assertEqual(
str(tree_util.tree_flatten_with_path(tree)[1]),
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
)
self.assertEqual(tree_util.tree_flatten(tree)[0],
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
def testPyTreeWithoutKeysIsntTreatedAsLeaf(self):
leaves, _ = tree_util.tree_flatten_with_path(Special([1, 2], [3, 4]))
self.assertLen(leaves, 4)
def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
self.assertLen(leaves, 1)
class RavelUtilTest(jtu.JaxTestCase):
def testFloats(self):
tree = [
jnp.array([3.0], jnp.float32),
jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32),
]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.float32)
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testInts(self):
tree = [jnp.array([3], jnp.int32),
jnp.array([[1, 2], [3, 4]], jnp.int32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.int32)
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedFloatInt(self):
tree = [jnp.array([3], jnp.int32),
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.int32))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedIntBool(self):
tree = [jnp.array([0], jnp.bool_),
jnp.array([[1, 2], [3, 4]], jnp.int32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.bool_, jnp.int32))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedFloatComplex(self):
tree = [jnp.array([1.], jnp.float32),
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.complex64))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testEmpty(self):
tree = []
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.float32) # convention
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testDtypePolymorphicUnravel(self):
# https://github.com/google/jax/issues/7809
x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = flatten_util.ravel_pytree(x)
y = x_flat < 5.3
x_ = unravel(y)
self.assertEqual(x_.dtype, y.dtype)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testDtypeMonomorphicUnravel(self):
# https://github.com/google/jax/issues/7809
x1 = jnp.arange(10, dtype=jnp.float32)
x2 = jnp.arange(10, dtype=jnp.int32)
x_flat, unravel = flatten_util.ravel_pytree((x1, x2))
y = x_flat < 5.3
with self.assertRaisesRegex(TypeError, 'but expected dtype'):
_ = unravel(y)
def test_no_recompile(self):
x1 = jnp.array([1, 2])
x2 = jnp.array([3, 4])
x_flat1, unravel1 = flatten_util.ravel_pytree((x1, x2))
x_flat2, unravel2 = flatten_util.ravel_pytree((x1, x2))
num_traces = 0
def run(flat, unravel):
nonlocal num_traces
num_traces += 1
flat = flat + 1
return unravel(flat)
run = jax.jit(run, static_argnums=1)
run(x_flat1, unravel1)
run(x_flat2, unravel2)
self.assertEqual(num_traces, 1)
class TreePrefixErrorsTest(jtu.JaxTestCase):
def test_different_types(self):
e, = prefix_errors((1, 2), [1, 2])
expected = ("pytree structure error: different types at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_types_nested(self):
e, = prefix_errors(((1,), (2,)), ([3], (4,)))
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_types_multiple(self):
e1, e2 = prefix_errors(((1,), (2,)), ([3], [4]))
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')
def test_different_num_children_tuple(self):
e, = prefix_errors((1,), (2, 3))
expected = ("pytree structure error: different lengths of tuple "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_list(self):
e, = prefix_errors([1], [2, 3])
expected = ("pytree structure error: different lengths of list "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_generic(self):
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_nested(self):
e, = prefix_errors([[1]], [[2, 3]])
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_multiple(self):
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')
def test_different_num_children_print_key_diff(self):
e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3})
expected = ("so the symmetric difference on key sets is\n"
" b")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_metadata(self):
e, = prefix_errors({1: 2}, {3: 4})
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_metadata_nested(self):
e, = prefix_errors([{1: 2}], [{3: 4}])
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_metadata_multiple(self):
e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')
def test_fallback_keypath(self):
e, = prefix_errors(Special(1, [2]), Special(3, 4))
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[<flat index 1>\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_no_errors(self):
() = prefix_errors((1, 2), ((11, 12, 13), 2))
def test_different_structure_no_children(self):
e, = prefix_errors({}, {'a': []})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())