mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
1021 lines
36 KiB
Python
1021 lines
36 KiB
Python
# Copyright 2019 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import collections
|
|
import dataclasses
|
|
import functools
|
|
import pickle
|
|
import re
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import jax
|
|
from jax import tree_util
|
|
from jax import flatten_util
|
|
from jax._src import test_util as jtu
|
|
from jax._src.tree_util import prefix_errors, flatten_one_level
|
|
import jax.numpy as jnp
|
|
|
|
|
|
def _dummy_func(*args, **kwargs):
|
|
return
|
|
|
|
|
|
ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
|
|
|
|
class ANamedTupleSubclass(ATuple):
|
|
pass
|
|
|
|
ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
|
|
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
|
|
lambda bar, foo: ATuple2(foo[0], bar))
|
|
|
|
class AnObject:
|
|
|
|
def __init__(self, x, y, z):
|
|
self.x = x
|
|
self.y = y
|
|
self.z = z
|
|
|
|
def __eq__(self, other):
|
|
return self.x == other.x and self.y == other.y and self.z == other.z
|
|
|
|
def __hash__(self):
|
|
return hash((self.x, self.y, self.z))
|
|
|
|
def __repr__(self):
|
|
return f"AnObject({self.x},{self.y},{self.z})"
|
|
|
|
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
|
|
lambda z, xy: AnObject(xy[0], xy[1], z))
|
|
|
|
class AnObject2(AnObject): pass
|
|
|
|
tree_util.register_pytree_with_keys(
|
|
AnObject2,
|
|
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
|
|
lambda z, xy: AnObject2(xy[0], xy[1], z), # unflatten (no key involved)
|
|
)
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class Special:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def __repr__(self):
|
|
return f"Special(x={self.x}, y={self.y})"
|
|
|
|
def tree_flatten(self):
|
|
return ((self.x, self.y), None)
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, aux_data, children):
|
|
return cls(*children)
|
|
|
|
def __eq__(self, other):
|
|
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
|
|
|
|
|
|
@tree_util.register_pytree_with_keys_class
|
|
class SpecialWithKeys(Special):
|
|
def tree_flatten_with_keys(self):
|
|
return (((tree_util.GetAttrKey('x'), self.x),
|
|
(tree_util.GetAttrKey('y'), self.y)), None)
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class FlatCache:
|
|
def __init__(self, structured, *, leaves=None, treedef=None):
|
|
if treedef is None:
|
|
leaves, treedef = tree_util.tree_flatten(structured)
|
|
self._structured = structured
|
|
self.treedef = treedef
|
|
self.leaves = leaves
|
|
|
|
def __hash__(self):
|
|
return hash(self.structured)
|
|
|
|
def __eq__(self, other):
|
|
return self.structured == other.structured
|
|
|
|
def __repr__(self):
|
|
return f"FlatCache({self.structured!r})"
|
|
|
|
@property
|
|
def structured(self):
|
|
if self._structured is None:
|
|
self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
|
|
return self._structured
|
|
|
|
def tree_flatten(self):
|
|
return self.leaves, self.treedef
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, meta, data):
|
|
if not tree_util.all_leaves(data):
|
|
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
|
|
return FlatCache(None, leaves=data, treedef=meta)
|
|
|
|
|
|
@tree_util.register_static
|
|
class StaticInt(int):
|
|
pass
|
|
|
|
|
|
@tree_util.register_static
|
|
class StaticTuple(tuple):
|
|
pass
|
|
|
|
|
|
@tree_util.register_static
|
|
class StaticDict(dict):
|
|
pass
|
|
|
|
|
|
@tree_util.register_static
|
|
@dataclasses.dataclass
|
|
class BlackBox:
|
|
"""Stores a value but pretends to be equal to every other black box."""
|
|
|
|
value: int
|
|
|
|
def __hash__(self):
|
|
return 0
|
|
|
|
def __eq__(self, other):
|
|
return isinstance(other, BlackBox)
|
|
|
|
|
|
TREES = (
|
|
(None,),
|
|
((None,),),
|
|
((),),
|
|
(([()]),),
|
|
((1, 2),),
|
|
(((1, "foo"), ["bar", (3, None, 7)]),),
|
|
([3],),
|
|
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
|
|
([AnObject(3, None, [4, "foo"])],),
|
|
([AnObject2(3, None, [4, "foo"])],),
|
|
(Special(2, 3.),),
|
|
({"a": 1, "b": 2},),
|
|
(StaticInt(1),),
|
|
(StaticTuple((2, 3)),),
|
|
(StaticDict(foo=4, bar=5),),
|
|
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
|
(collections.defaultdict(dict,
|
|
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
|
(ANamedTupleSubclass(foo="hello", bar=3.5),),
|
|
(FlatCache(None),),
|
|
(FlatCache(1),),
|
|
(FlatCache({"a": [1, 2]}),),
|
|
(BlackBox(value=2),),
|
|
)
|
|
|
|
|
|
TREE_STRINGS = (
|
|
"PyTreeDef(None)",
|
|
"PyTreeDef((None,))",
|
|
"PyTreeDef(())",
|
|
"PyTreeDef([()])",
|
|
"PyTreeDef((*, *))",
|
|
"PyTreeDef(((*, *), [*, (*, None, *)]))",
|
|
"PyTreeDef([*])",
|
|
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
|
|
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
|
|
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
|
|
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
|
|
"PyTreeDef(CustomNode(Special[None], [*, *]))",
|
|
"PyTreeDef({'a': *, 'b': *})",
|
|
"PyTreeDef(CustomNode(StaticInt[1], []))",
|
|
"PyTreeDef(CustomNode(StaticTuple[(2, 3)], []))",
|
|
"PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
|
|
)
|
|
|
|
# pytest expects "tree_util_test.ATuple"
|
|
STRS = []
|
|
for tree_str in TREE_STRINGS:
|
|
tree_str = re.escape(tree_str)
|
|
tree_str = tree_str.replace("__main__", ".*")
|
|
STRS.append(tree_str)
|
|
TREE_STRINGS = STRS
|
|
|
|
LEAVES = (
|
|
("foo",),
|
|
(0.1,),
|
|
(1,),
|
|
(object(),),
|
|
)
|
|
|
|
# All except those decorated by register_pytree_node_class
|
|
TREES_WITH_KEYPATH = (
|
|
(None,),
|
|
((None,),),
|
|
((),),
|
|
(([()]),),
|
|
((1, 0),),
|
|
(((1, "foo"), ["bar", (3, None, 7)]),),
|
|
([3],),
|
|
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
|
|
([AnObject2(3, None, [4, "foo"])],),
|
|
(SpecialWithKeys(2, 3.),),
|
|
({"a": 1, "b": 0},),
|
|
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
|
(collections.defaultdict(dict,
|
|
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
|
(ANamedTupleSubclass(foo="hello", bar=3.5),),
|
|
(StaticInt(1),),
|
|
(StaticTuple((2, 3)),),
|
|
(StaticDict(foo=4, bar=5),),
|
|
(BlackBox(value=2),),
|
|
)
|
|
|
|
|
|
class TreeTest(jtu.JaxTestCase):
|
|
|
|
@parameterized.parameters(*(TREES + LEAVES))
|
|
def testRoundtrip(self, inputs):
|
|
xs, tree = tree_util.tree_flatten(inputs)
|
|
actual = tree_util.tree_unflatten(tree, xs)
|
|
self.assertEqual(actual, inputs)
|
|
|
|
@parameterized.parameters(*(TREES + LEAVES))
|
|
def testRoundtripWithFlattenUpTo(self, inputs):
|
|
_, tree = tree_util.tree_flatten(inputs)
|
|
xs = tree.flatten_up_to(inputs)
|
|
actual = tree_util.tree_unflatten(tree, xs)
|
|
self.assertEqual(actual, inputs)
|
|
|
|
@parameterized.parameters(
|
|
(tree_util.Partial(_dummy_func),),
|
|
(tree_util.Partial(_dummy_func, 1, 2),),
|
|
(tree_util.Partial(_dummy_func, x="a"),),
|
|
(tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
|
|
)
|
|
def testRoundtripPartial(self, inputs):
|
|
xs, tree = tree_util.tree_flatten(inputs)
|
|
actual = tree_util.tree_unflatten(tree, xs)
|
|
# functools.partial does not support equality comparisons:
|
|
# https://stackoverflow.com/a/32786109/809705
|
|
self.assertEqual(actual.func, inputs.func)
|
|
self.assertEqual(actual.args, inputs.args)
|
|
self.assertEqual(actual.keywords, inputs.keywords)
|
|
|
|
def testPartialDoesNotMergeWithOtherPartials(self):
|
|
def f(a, b, c): pass
|
|
g = functools.partial(f, 2)
|
|
h = tree_util.Partial(g, 3)
|
|
self.assertEqual(h.args, (3,))
|
|
|
|
def testPartialFuncAttributeHasStableHash(self):
|
|
# https://github.com/google/jax/issues/9429
|
|
fun = functools.partial(print, 1)
|
|
p1 = tree_util.Partial(fun, 2)
|
|
p2 = tree_util.Partial(fun, 2)
|
|
self.assertEqual(fun, p1.func)
|
|
self.assertEqual(p1.func, fun)
|
|
self.assertEqual(p1.func, p2.func)
|
|
self.assertEqual(hash(p1.func), hash(p2.func))
|
|
|
|
def testChildren(self):
|
|
_, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
|
|
_, c0 = tree_util.tree_flatten((0, 0, 0))
|
|
_, c1 = tree_util.tree_flatten((7,))
|
|
self.assertEqual([c0, c1], tree.children())
|
|
|
|
def testTreedefTupleFromChildren(self):
|
|
# https://github.com/google/jax/issues/7377
|
|
tree = ((1, 2, (3, 4)), (5,))
|
|
leaves, treedef1 = tree_util.tree_flatten(tree)
|
|
treedef2 = tree_util.treedef_tuple(treedef1.children())
|
|
self.assertEqual(treedef1.num_leaves, len(leaves))
|
|
self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
|
|
self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
|
|
|
|
def testTreedefTupleComparesEqual(self):
|
|
# https://github.com/google/jax/issues/9066
|
|
self.assertEqual(tree_util.tree_structure((3,)),
|
|
tree_util.treedef_tuple((tree_util.tree_structure(3),)))
|
|
|
|
def testFlattenOrder(self):
|
|
flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
|
|
flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
|
|
flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9])
|
|
self.assertEqual(flat1, list(range(10)))
|
|
self.assertEqual(flat2, list(range(10)))
|
|
self.assertEqual(flat3, list(range(10)))
|
|
|
|
@parameterized.parameters(
|
|
(
|
|
[(1, 2), None, ATuple(foo=3, bar=7)],
|
|
[({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)],
|
|
[{"foo": 7}, (3, 4), (11, 9), None],
|
|
),
|
|
([1], [{"a": 7}], [{"a": 7}]),
|
|
([1], [[7]], [[7]]),
|
|
([1], [(7,)], [(7,)]),
|
|
((1, 2), ({"a": 7}, {"a": 8}), [{"a": 7}, {"a": 8}]),
|
|
((1,), ([7],), [[7]]),
|
|
((1,), ((7,),), [(7,)]),
|
|
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
|
|
({"a": 1}, {"a": (7,)}, [(7,)]),
|
|
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
|
|
)
|
|
def testFlattenUpTo(self, tree, xs, expected):
|
|
_, tree_def = tree_util.tree_flatten(tree)
|
|
out = tree_def.flatten_up_to(xs)
|
|
self.assertEqual(out, expected)
|
|
|
|
@parameterized.parameters(
|
|
([1, 2], [7], re.escape("List arity mismatch: 1 != 2; list: [7].")),
|
|
((1,), (7, 8), re.escape("Tuple arity mismatch: 2 != 1; tuple: (7, 8).")),
|
|
(
|
|
{"a": 1},
|
|
{"a": 7, "b": 8},
|
|
re.escape(
|
|
"Dict key mismatch; expected keys: ['a']; dict: {'a': 7, 'b': 8}."
|
|
),
|
|
),
|
|
(
|
|
{"a": 1},
|
|
{"b": 7},
|
|
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
|
),
|
|
([1], {"a": 7}, re.escape("Expected list, got {'a': 7}.")),
|
|
([1], (7,), re.escape("Expected list, got (7,).")),
|
|
((1,), [7], re.escape("Expected tuple, got [7].")),
|
|
((1,), {"b": 7}, re.escape("Expected tuple, got {'b': 7}.")),
|
|
({"a": 1}, (7,), re.escape("Expected dict, got (7,).")),
|
|
({"a": 1}, [7], re.escape("Expected dict, got [7].")),
|
|
([[1]], [7], re.escape("Expected list, got 7.")),
|
|
([[1]], [(7,)], re.escape("Expected list, got (7,).")),
|
|
([[1]], [{"a": 7}], re.escape("Expected list, got {'a': 7}.")),
|
|
([(1,)], [7], re.escape("Expected tuple, got 7.")),
|
|
([(1,)], [[7]], re.escape("Expected tuple, got [7].")),
|
|
([(1,)], [{"a": 7}], re.escape("Expected tuple, got {'a': 7}.")),
|
|
([{"a": 1}], [7], re.escape("Expected dict, got 7.")),
|
|
([{"a": 1}], [[7]], re.escape("Expected dict, got [7].")),
|
|
([{"a": 1}], [(7,)], re.escape("Expected dict, got (7,).")),
|
|
(
|
|
[{"a": 1}],
|
|
[{"b": 7}],
|
|
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
|
),
|
|
(([1],), (7,), re.escape("Expected list, got 7.")),
|
|
(([1],), ((7,),), re.escape("Expected list, got (7,).")),
|
|
(([1],), ({"a": 7},), re.escape("Expected list, got {'a': 7}.")),
|
|
(((1,),), (7,), re.escape("Expected tuple, got 7.")),
|
|
(((1,),), ([7],), re.escape("Expected tuple, got [7].")),
|
|
(((1,),), ({"a": 7},), re.escape("Expected tuple, got {'a': 7}.")),
|
|
(({"a": 1},), (7,), re.escape("Expected dict, got 7.")),
|
|
(({"a": 1},), ([7],), re.escape("Expected dict, got [7].")),
|
|
(({"a": 1},), ((7,),), re.escape("Expected dict, got (7,).")),
|
|
(
|
|
({"a": 1},),
|
|
({"b": 7},),
|
|
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
|
),
|
|
({"a": [1]}, {"a": 7}, re.escape("Expected list, got 7.")),
|
|
({"a": [1]}, {"a": (7,)}, re.escape("Expected list, got (7,).")),
|
|
({"a": [1]}, {"a": {"a": 7}}, re.escape("Expected list, got {'a': 7}.")),
|
|
({"a": (1,)}, {"a": 7}, re.escape("Expected tuple, got 7.")),
|
|
({"a": (1,)}, {"a": [7]}, re.escape("Expected tuple, got [7].")),
|
|
(
|
|
{"a": (1,)},
|
|
{"a": {"a": 7}},
|
|
re.escape("Expected tuple, got {'a': 7}."),
|
|
),
|
|
({"a": {"a": 1}}, {"a": 7}, re.escape("Expected dict, got 7.")),
|
|
({"a": {"a": 1}}, {"a": [7]}, re.escape("Expected dict, got [7].")),
|
|
({"a": {"a": 1}}, {"a": (7,)}, re.escape("Expected dict, got (7,).")),
|
|
(
|
|
{"a": {"a": 1}},
|
|
{"a": {"b": 7}},
|
|
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
|
),
|
|
(
|
|
[ATuple(foo=1, bar=2)],
|
|
[(1, 2)],
|
|
re.escape("Expected named tuple, got (1, 2)."),
|
|
),
|
|
(
|
|
[ATuple(foo=1, bar=2)],
|
|
[ATuple2(foo=1, bar=2)],
|
|
re.escape("Named tuple type mismatch"),
|
|
),
|
|
(
|
|
[AnObject(x=[1], y=(2,), z={"a": [1]})],
|
|
[([1], (2,), {"a": [1]})],
|
|
re.escape("Custom node type mismatch"),
|
|
),
|
|
)
|
|
def testFlattenUpToErrors(self, tree, xs, error):
|
|
_, tree_def = tree_util.tree_flatten(tree)
|
|
with self.assertRaisesRegex(ValueError, error):
|
|
tree_def.flatten_up_to(xs)
|
|
|
|
def testTreeMap(self):
|
|
x = ((1, 2), (3, 4, 5))
|
|
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
|
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y)
|
|
self.assertEqual(out, (((1, [3]), (2, None)),
|
|
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
|
|
|
def testTreeMapWithIsLeafArgument(self):
|
|
x = ((1, 2), [3, 4, 5])
|
|
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
|
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y,
|
|
is_leaf=lambda n: isinstance(n, list))
|
|
self.assertEqual(out, (((1, [3]), (2, None)),
|
|
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
|
|
|
def testTreeReduceWithIsLeafArgument(self):
|
|
out = tree_util.tree_reduce(lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]],
|
|
is_leaf=lambda l: isinstance(l, tuple))
|
|
self.assertEqual(out, (1, 2, 3, 4, 5, 6))
|
|
|
|
@parameterized.parameters(
|
|
tree_util.tree_leaves,
|
|
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
|
|
def testFlattenIsLeaf(self, leaf_fn):
|
|
x = [(1, 2), (3, 4), (5, 6)]
|
|
leaves = leaf_fn(x, is_leaf=lambda t: False)
|
|
self.assertEqual(leaves, [1, 2, 3, 4, 5, 6])
|
|
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, tuple))
|
|
self.assertEqual(leaves, x)
|
|
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, list))
|
|
self.assertEqual(leaves, [x])
|
|
leaves = leaf_fn(x, is_leaf=lambda t: True)
|
|
self.assertEqual(leaves, [x])
|
|
|
|
y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
|
|
leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple))
|
|
self.assertEqual(leaves, [(1,), (2,), (3,)])
|
|
|
|
@parameterized.parameters(
|
|
tree_util.tree_structure,
|
|
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
|
|
def testStructureIsLeaf(self, structure_fn):
|
|
x = [(1, 2), (3, 4), (5, 6)]
|
|
treedef = structure_fn(x, is_leaf=lambda t: False)
|
|
self.assertEqual(treedef.num_leaves, 6)
|
|
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, tuple))
|
|
self.assertEqual(treedef.num_leaves, 3)
|
|
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, list))
|
|
self.assertEqual(treedef.num_leaves, 1)
|
|
treedef = structure_fn(x, is_leaf=lambda t: True)
|
|
self.assertEqual(treedef.num_leaves, 1)
|
|
|
|
y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
|
|
treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
|
|
self.assertEqual(treedef.num_leaves, 3)
|
|
|
|
@parameterized.parameters(*TREES)
|
|
def testRoundtripIsLeaf(self, tree):
|
|
xs, treedef = tree_util.tree_flatten(
|
|
tree, is_leaf=lambda t: isinstance(t, tuple))
|
|
recon_tree = tree_util.tree_unflatten(treedef, xs)
|
|
self.assertEqual(recon_tree, tree)
|
|
|
|
@parameterized.parameters(*TREES)
|
|
def testAllLeavesWithTrees(self, tree):
|
|
leaves = tree_util.tree_leaves(tree)
|
|
self.assertTrue(tree_util.all_leaves(leaves))
|
|
self.assertFalse(tree_util.all_leaves([tree]))
|
|
|
|
@parameterized.parameters(*LEAVES)
|
|
def testAllLeavesWithLeaves(self, leaf):
|
|
self.assertTrue(tree_util.all_leaves([leaf]))
|
|
|
|
@parameterized.parameters(*TREES)
|
|
def testCompose(self, tree):
|
|
treedef = tree_util.tree_structure(tree)
|
|
inner_treedef = tree_util.tree_structure(["*", "*", "*"])
|
|
composed_treedef = treedef.compose(inner_treedef)
|
|
expected_leaves = treedef.num_leaves * inner_treedef.num_leaves
|
|
self.assertEqual(composed_treedef.num_leaves, expected_leaves)
|
|
expected_nodes = ((treedef.num_nodes - treedef.num_leaves) +
|
|
(inner_treedef.num_nodes * treedef.num_leaves))
|
|
self.assertEqual(composed_treedef.num_nodes, expected_nodes)
|
|
leaves = [1] * expected_leaves
|
|
composed = tree_util.tree_unflatten(composed_treedef, leaves)
|
|
self.assertEqual(leaves, tree_util.tree_leaves(composed))
|
|
|
|
@parameterized.parameters(*TREES)
|
|
def testTranspose(self, tree):
|
|
outer_treedef = tree_util.tree_structure(tree)
|
|
if not outer_treedef.num_leaves:
|
|
self.skipTest("Skipping empty tree")
|
|
inner_treedef = tree_util.tree_structure([1, 1, 1])
|
|
nested = tree_util.tree_map(lambda x: [x, x, x], tree)
|
|
actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
|
|
self.assertEqual(actual, [tree, tree, tree])
|
|
|
|
def testTransposeMismatchOuter(self):
|
|
tree = {"a": [1, 2], "b": [3, 4]}
|
|
outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
|
|
inner_treedef = tree_util.tree_structure([1, 2])
|
|
with self.assertRaisesRegex(TypeError, "Mismatch"):
|
|
tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
|
|
|
|
def testTransposeMismatchInner(self):
|
|
tree = {"a": [1, 2], "b": [3, 4]}
|
|
outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
|
|
inner_treedef = tree_util.tree_structure([1, 2, 3])
|
|
with self.assertRaisesRegex(TypeError, "Mismatch"):
|
|
tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
|
|
|
|
def testTransposeWithCustomObject(self):
|
|
outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
|
|
inner_treedef = tree_util.tree_structure([1, 2])
|
|
expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
|
|
actual = tree_util.tree_transpose(outer_treedef, inner_treedef,
|
|
FlatCache({"a": [3, 4], "b": [5, 6]}))
|
|
self.assertEqual(expected, actual)
|
|
|
|
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
|
|
def testStringRepresentation(self, tree, correct_string):
|
|
"""Checks that the string representation of a tree works."""
|
|
treedef = tree_util.tree_structure(tree)
|
|
self.assertRegex(str(treedef), correct_string)
|
|
|
|
def testTreeDefWithEmptyDictStringRepresentation(self):
|
|
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
|
|
|
|
@parameterized.parameters(*TREES)
|
|
def testPickleRoundTrip(self, tree):
|
|
leaves, treedef = tree_util.tree_flatten(tree)
|
|
treedef_restored = pickle.loads(pickle.dumps(treedef))
|
|
self.assertEqual(treedef, treedef_restored)
|
|
reconstituted = treedef_restored.unflatten(leaves)
|
|
self.assertEqual(tree, reconstituted)
|
|
|
|
def testDictKeysSortable(self):
|
|
d = {"a": 1, 2: "b"}
|
|
with self.assertRaisesRegex(TypeError, "'<' not supported"):
|
|
_, _ = tree_util.tree_flatten(d)
|
|
|
|
def testFlattenDictKeyOrder(self):
|
|
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
|
|
leaves, treedef = tree_util.tree_flatten(d)
|
|
self.assertEqual(leaves, [1, 2, 1, 2])
|
|
self.assertEqual(
|
|
str(treedef), "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})"
|
|
)
|
|
restored_d = tree_util.tree_unflatten(treedef, leaves)
|
|
self.assertEqual(list(restored_d.keys()), ["a", "b", "c"])
|
|
|
|
def testWalk(self):
|
|
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
|
|
leaves, treedef = tree_util.tree_flatten(d)
|
|
|
|
nodes_visited = []
|
|
node_data_visited = []
|
|
leaves_visited = []
|
|
|
|
def f_node(node, node_data):
|
|
nodes_visited.append(node)
|
|
node_data_visited.append(node_data)
|
|
|
|
def f_leaf(leaf):
|
|
leaves_visited.append(leaf)
|
|
|
|
treedef.walk(f_node, f_leaf, leaves)
|
|
self.assertEqual(leaves_visited, [1, 2, 1, 2])
|
|
self.assertEqual(nodes_visited, [(None, None), (None, None, None)])
|
|
self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]])
|
|
|
|
@parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES))
|
|
def testRoundtripWithPath(self, inputs):
|
|
key_leaves, treedef = tree_util.tree_flatten_with_path(inputs)
|
|
actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves])
|
|
self.assertEqual(actual, inputs)
|
|
|
|
def testTreeMapWithPath(self):
|
|
tree = [{i: i for i in range(10)}]
|
|
all_zeros = tree_util.tree_map_with_path(
|
|
lambda kp, val: val - kp[1].key + kp[0].idx, tree
|
|
)
|
|
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
|
|
|
|
def testTreeMapWithPathMultipleTrees(self):
|
|
tree1 = [AnObject2(x=12,
|
|
y={'cin': [1, 4, 10], 'bar': None},
|
|
z='constantdef'),
|
|
5]
|
|
tree2 = [AnObject2(x=2,
|
|
y={'cin': [2, 2, 2], 'bar': None},
|
|
z='constantdef'),
|
|
2]
|
|
from_two_trees = tree_util.tree_map_with_path(
|
|
lambda kp, a, b: a + b, tree1, tree2
|
|
)
|
|
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
|
|
self.assertEqual(from_two_trees, from_one_tree)
|
|
|
|
def testTreeLeavesWithPath(self):
|
|
tree = [{i: i for i in range(10)}]
|
|
actual = tree_util.tree_leaves_with_path(tree)
|
|
expected = [((tree_util.SequenceKey(0), tree_util.DictKey(i)), i)
|
|
for i in range(10)]
|
|
self.assertEqual(actual, expected)
|
|
|
|
def testKeyStr(self):
|
|
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
|
|
flattened, _ = tree_util.tree_flatten_with_path(tree1)
|
|
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
|
|
self.assertEqual(
|
|
strs,
|
|
[
|
|
"[0].foo: 12",
|
|
"[0].bar['cin'][0]: 1",
|
|
"[0].bar['cin'][1]: 4",
|
|
"[0].bar['cin'][2]: 10",
|
|
"[1]: [0 1 2 3 4]",
|
|
],
|
|
)
|
|
|
|
def testTreeMapWithPathWithIsLeafArgument(self):
|
|
x = ((1, 2), [3, 4, 5])
|
|
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
|
|
out = tree_util.tree_map_with_path(
|
|
lambda kp, *xs: (kp[0].idx, *xs), x, y,
|
|
is_leaf=lambda n: isinstance(n, list))
|
|
self.assertEqual(out, (((0, 1, [3]),
|
|
(0, 2, jnp.array(0))),
|
|
(1, [3, 4, 5], ([0], 7, [5, 6]))))
|
|
|
|
def testFlattenWithPathWithIsLeafArgument(self):
|
|
def is_empty(x):
|
|
try:
|
|
children, _ = flatten_one_level(x)
|
|
except ValueError:
|
|
return True # Cannot flatten x; means it must be a leaf
|
|
return len(children) == 0
|
|
|
|
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
|
tree1 = {'a': 1,
|
|
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
|
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
|
|
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
|
|
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
|
|
self.assertEqual(
|
|
strs,
|
|
[
|
|
"['a']: 1",
|
|
"['obj']x: EmptyTuple()",
|
|
"['obj']y: 0",
|
|
"['sub'][0]: [1 2]",
|
|
"['sub'][1].foo: ()",
|
|
"['sub'][1].bar[0]: None",
|
|
],
|
|
)
|
|
|
|
def testFlattenOneLevel(self):
|
|
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
|
tree1 = {'a': 1,
|
|
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
|
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
|
|
self.assertEqual(flatten_one_level(tree1["sub"])[0],
|
|
tree1["sub"])
|
|
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],
|
|
[(), [None]])
|
|
self.assertEqual(flatten_one_level(tree1["obj"])[0],
|
|
[EmptyTuple(), 0])
|
|
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
|
|
flatten_one_level(1)
|
|
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
|
|
flatten_one_level(jnp.array((1, 2)))
|
|
|
|
def testOptionalFlatten(self):
|
|
@tree_util.register_pytree_with_keys_class
|
|
class FooClass:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
def tree_flatten(self):
|
|
return ((self.x, self.y), 'treedef')
|
|
def tree_flatten_with_keys(self):
|
|
return (((tree_util.GetAttrKey('x'), self.x),
|
|
(tree_util.GetAttrKey('x'), self.y)), 'treedef')
|
|
@classmethod
|
|
def tree_unflatten(cls, _, children):
|
|
return cls(*children)
|
|
|
|
tree = FooClass(x=1, y=2)
|
|
self.assertEqual(
|
|
str(tree_util.tree_flatten(tree)[1]),
|
|
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
|
)
|
|
self.assertEqual(
|
|
str(tree_util.tree_flatten_with_path(tree)[1]),
|
|
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
|
)
|
|
self.assertEqual(tree_util.tree_flatten(tree)[0],
|
|
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
|
|
|
|
def testPyTreeWithoutKeysIsntTreatedAsLeaf(self):
|
|
leaves, _ = tree_util.tree_flatten_with_path(Special([1, 2], [3, 4]))
|
|
self.assertLen(leaves, 4)
|
|
|
|
def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
|
|
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
|
|
self.assertLen(leaves, 1)
|
|
|
|
|
|
class StaticTest(parameterized.TestCase):
|
|
|
|
@parameterized.parameters(
|
|
(StaticInt(2),),
|
|
(StaticTuple((2, None)),),
|
|
(StaticDict(foo=2),),
|
|
)
|
|
def test_trace_just_once_with_same_static(self, y):
|
|
num_called = 0
|
|
|
|
@jax.jit
|
|
def fn(x: int, static_y: StaticInt):
|
|
nonlocal num_called
|
|
num_called += 1
|
|
unstatic_y = type(static_y).__base__(static_y)
|
|
[y] = tree_util.tree_leaves(unstatic_y)
|
|
return x + y
|
|
|
|
fn(1, y)
|
|
fn(3, y)
|
|
self.assertEqual(num_called, 1)
|
|
|
|
def test_name(self):
|
|
self.assertEqual(StaticInt.__name__, "StaticInt")
|
|
self.assertEqual(BlackBox.__name__, "BlackBox")
|
|
|
|
@parameterized.parameters(
|
|
(StaticInt(2), StaticInt(4)),
|
|
(StaticTuple((2, None)), StaticTuple((4, None))),
|
|
(StaticDict(foo=2), StaticDict(foo=4)),
|
|
)
|
|
def test_trace_twice_with_different_static(self, y1, y2):
|
|
num_called = 0
|
|
|
|
@jax.jit
|
|
def fn(x: int, static_y: StaticInt):
|
|
nonlocal num_called
|
|
num_called += 1
|
|
unstatic_y = type(static_y).__base__(static_y)
|
|
[y] = tree_util.tree_leaves(unstatic_y)
|
|
return x + y
|
|
|
|
fn(1, y1)
|
|
fn(3, y2)
|
|
self.assertEqual(num_called, 2)
|
|
|
|
def test_trace_just_once_if_static_looks_constant(self):
|
|
num_called = 0
|
|
|
|
@jax.jit
|
|
def fn(x: int, static_y: BlackBox):
|
|
nonlocal num_called
|
|
num_called += 1
|
|
return x + static_y.value
|
|
|
|
self.assertEqual(fn(1, BlackBox(2)), 3)
|
|
self.assertEqual(fn(3, BlackBox(1)), 5)
|
|
self.assertEqual(num_called, 1)
|
|
|
|
|
|
class RavelUtilTest(jtu.JaxTestCase):
|
|
|
|
def testFloats(self):
|
|
tree = [
|
|
jnp.array([3.0], jnp.float32),
|
|
jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32),
|
|
]
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.float32)
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
def testInts(self):
|
|
tree = [jnp.array([3], jnp.int32),
|
|
jnp.array([[1, 2], [3, 4]], jnp.int32)]
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.int32)
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
|
|
def testMixedFloatInt(self):
|
|
tree = [jnp.array([3], jnp.int32),
|
|
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.int32))
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
|
|
def testMixedIntBool(self):
|
|
tree = [jnp.array([0], jnp.bool_),
|
|
jnp.array([[1, 2], [3, 4]], jnp.int32)]
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.bool_, jnp.int32))
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
|
|
def testMixedFloatComplex(self):
|
|
tree = [jnp.array([1.], jnp.float32),
|
|
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.complex64))
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
def testEmpty(self):
|
|
tree = []
|
|
raveled, unravel = flatten_util.ravel_pytree(tree)
|
|
self.assertEqual(raveled.dtype, jnp.float32) # convention
|
|
tree_ = unravel(raveled)
|
|
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
|
|
|
def testDtypePolymorphicUnravel(self):
|
|
# https://github.com/google/jax/issues/7809
|
|
x = jnp.arange(10, dtype=jnp.float32)
|
|
x_flat, unravel = flatten_util.ravel_pytree(x)
|
|
y = x_flat < 5.3
|
|
x_ = unravel(y)
|
|
self.assertEqual(x_.dtype, y.dtype)
|
|
|
|
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
|
|
def testDtypeMonomorphicUnravel(self):
|
|
# https://github.com/google/jax/issues/7809
|
|
x1 = jnp.arange(10, dtype=jnp.float32)
|
|
x2 = jnp.arange(10, dtype=jnp.int32)
|
|
x_flat, unravel = flatten_util.ravel_pytree((x1, x2))
|
|
y = x_flat < 5.3
|
|
with self.assertRaisesRegex(TypeError, 'but expected dtype'):
|
|
_ = unravel(y)
|
|
|
|
def test_no_recompile(self):
|
|
x1 = jnp.array([1, 2])
|
|
x2 = jnp.array([3, 4])
|
|
x_flat1, unravel1 = flatten_util.ravel_pytree((x1, x2))
|
|
x_flat2, unravel2 = flatten_util.ravel_pytree((x1, x2))
|
|
num_traces = 0
|
|
|
|
def run(flat, unravel):
|
|
nonlocal num_traces
|
|
num_traces += 1
|
|
flat = flat + 1
|
|
return unravel(flat)
|
|
|
|
run = jax.jit(run, static_argnums=1)
|
|
|
|
run(x_flat1, unravel1)
|
|
run(x_flat2, unravel2)
|
|
self.assertEqual(num_traces, 1)
|
|
|
|
|
|
class TreePrefixErrorsTest(jtu.JaxTestCase):
|
|
|
|
def test_different_types(self):
|
|
e, = prefix_errors((1, 2), [1, 2])
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_types_nested(self):
|
|
e, = prefix_errors(((1,), (2,)), ([3], (4,)))
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_types_multiple(self):
|
|
e1, e2 = prefix_errors(((1,), (2,)), ([3], [4]))
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e1('in_axes')
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
r" in_axes\[1\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e2('in_axes')
|
|
|
|
def test_different_num_children_tuple(self):
|
|
e, = prefix_errors((1,), (2, 3))
|
|
expected = ("pytree structure error: different lengths of tuple "
|
|
"at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_num_children_list(self):
|
|
e, = prefix_errors([1], [2, 3])
|
|
expected = ("pytree structure error: different lengths of list "
|
|
"at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
|
|
def test_different_num_children_generic(self):
|
|
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
|
|
expected = ("pytree structure error: different numbers of pytree children "
|
|
"at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_num_children_nested(self):
|
|
e, = prefix_errors([[1]], [[2, 3]])
|
|
expected = ("pytree structure error: different lengths of list "
|
|
"at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_num_children_multiple(self):
|
|
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
|
|
expected = ("pytree structure error: different lengths of list "
|
|
"at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e1('in_axes')
|
|
expected = ("pytree structure error: different lengths of list "
|
|
"at key path\n"
|
|
r" in_axes\[1\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e2('in_axes')
|
|
|
|
def test_different_num_children_print_key_diff(self):
|
|
e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3})
|
|
expected = ("so the symmetric difference on key sets is\n"
|
|
" b")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_metadata(self):
|
|
e, = prefix_errors({1: 2}, {3: 4})
|
|
expected = ("pytree structure error: different pytree metadata "
|
|
"at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_metadata_nested(self):
|
|
e, = prefix_errors([{1: 2}], [{3: 4}])
|
|
expected = ("pytree structure error: different pytree metadata "
|
|
"at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_different_metadata_multiple(self):
|
|
e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
|
|
expected = ("pytree structure error: different pytree metadata "
|
|
"at key path\n"
|
|
r" in_axes\[0\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e1('in_axes')
|
|
expected = ("pytree structure error: different pytree metadata "
|
|
"at key path\n"
|
|
r" in_axes\[1\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e2('in_axes')
|
|
|
|
def test_fallback_keypath(self):
|
|
e, = prefix_errors(Special(1, [2]), Special(3, 4))
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
r" in_axes\[<flat index 1>\]")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
def test_no_errors(self):
|
|
() = prefix_errors((1, 2), ((11, 12, 13), 2))
|
|
|
|
def test_different_structure_no_children(self):
|
|
e, = prefix_errors({}, {'a': []})
|
|
expected = ("pytree structure error: different numbers of pytree children "
|
|
"at key path\n"
|
|
" in_axes")
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
raise e('in_axes')
|
|
|
|
|
|
if __name__ == "__main__":
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|