# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-25 18:16:05 +00:00).
# 534 lines, 19 KiB, Python.
# Copyright 2019 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import collections
|
|
import functools
|
|
import re
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
from jax import tree_util
|
|
from jax import flatten_util
|
|
from jax._src import test_util as jtu
|
|
from jax._src.lib import pytree as pytree
|
|
from jax._src.tree_util import _process_pytree, prefix_errors
|
|
import jax.numpy as jnp
|
|
|
|
|
|
# Value of the `version` attribute of the imported `pytree` module, or 0 when
# that module does not define one.
pytree_version = getattr(pytree, "version", 0)
|
|
|
|
def _dummy_func(*args, **kwargs):
  """Accept any arguments, do nothing, and return None."""
  return None
|
|
|
|
|
|
# Two-field namedtuple (`foo`, `bar`) used as a fixture by the tests below.
ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
|
|
|
|
class ANamedTupleSubclass(ATuple):
  """Subclass of ATuple that adds no fields or methods of its own."""
|
|
|
|
class AnObject:
  """Plain container with three attributes: `x`, `y` and `z`.

  Instances compare equal when all three attributes compare equal, and hash
  by the tuple `(x, y, z)`, so they are only hashable when all three
  attributes are hashable.
  """

  def __init__(self, x, y, z):
    self.x = x
    self.y = y
    self.z = z

  def __eq__(self, other):
    # Defer to the other operand instead of raising AttributeError when
    # compared against something that is not an AnObject.
    if not isinstance(other, AnObject):
      return NotImplemented
    return self.x == other.x and self.y == other.y and self.z == other.z

  def __hash__(self):
    return hash((self.x, self.y, self.z))

  def __repr__(self):
    return f"AnObject({self.x},{self.y},{self.z})"
|
|
|
|
# Register AnObject as a pytree node: (x, y) are the children and z is the
# auxiliary data handed back on reconstruction.
tree_util.register_pytree_node(
    AnObject,
    lambda obj: ((obj.x, obj.y), obj.z),
    lambda aux, children: AnObject(children[0], children[1], aux),
)
|
|
|
|
@tree_util.register_pytree_node_class
class Special:
  """Two-attribute pytree node registered through the class decorator.

  Flattening yields `(x, y)` as children with `None` as auxiliary data.
  """

  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return f"Special(x={self.x}, y={self.y})"

  def tree_flatten(self):
    children = (self.x, self.y)
    return (children, None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # aux_data is unused; the children are passed positionally as (x, y).
    return cls(*children)

  def __eq__(self, other):
    # Equality requires the exact same type, not merely a subclass.
    if type(self) is not type(other):
      return False
    return (self.x, self.y) == (other.x, other.y)
|
|
|
|
@tree_util.register_pytree_node_class
class FlatCache:
  """Pytree node that stores a tree in flattened form (leaves + treedef).

  The structured value is rebuilt from the flattened form on access whenever
  the stored structured value is None.

  Args:
    structured: the tree to wrap, or None when `leaves`/`treedef` are given.
    leaves: flattened leaves; ignored when `treedef` is None.
    treedef: tree definition matching `leaves`; when None, both are computed
      by flattening `structured`.
  """

  def __init__(self, structured, *, leaves=None, treedef=None):
    if treedef is None:
      leaves, treedef = tree_util.tree_flatten(structured)
    self._structured = structured
    self.treedef = treedef
    self.leaves = leaves

  def __hash__(self):
    return hash(self.structured)

  def __eq__(self, other):
    # Defer to the other operand instead of raising AttributeError when
    # compared against something that is not a FlatCache.
    if not isinstance(other, FlatCache):
      return NotImplemented
    return self.structured == other.structured

  def __repr__(self):
    return f"FlatCache({self.structured!r})"

  @property
  def structured(self):
    """The wrapped tree, rebuilt from `treedef` and `leaves` if not stored."""
    if self._structured is None:
      self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
    return self._structured

  def tree_flatten(self):
    # Children are the cached leaves; the cached treedef is the aux data.
    return self.leaves, self.treedef

  @classmethod
  def tree_unflatten(cls, meta, data):
    # When the children are not all leaves, re-flatten so that the stored
    # leaves/treedef pair describes a fully flattened tree.
    if not tree_util.all_leaves(data):
      data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
    # Build through `cls` so subclasses round-trip as themselves.
    return cls(None, leaves=data, treedef=meta)
|
|
|
|
# Pytree fixtures. Each entry is a 1-tuple so it can be splatted directly into
# `parameterized.parameters`.
TREES = (
    # Builtin containers and None.
    (None,),
    ((None,),),
    ((),),
    ([()],),
    ((1, 2),),
    (((1, "foo"), ["bar", (3, None, 7)]),),
    ([3],),
    # Namedtuples and registered custom nodes.
    ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
    ([AnObject(3, None, [4, "foo"])],),
    (Special(2, 3.),),
    # Mapping types.
    ({"a": 1, "b": 2},),
    (collections.OrderedDict(
        [("foo", 34), ("baz", 101), ("something", -42)]),),
    (collections.defaultdict(
        dict, [("foo", 34), ("baz", 101), ("something", -42)]),),
    (ANamedTupleSubclass(foo="hello", bar=3.5),),
    # FlatCache wrappers.
    (FlatCache(None),),
    (FlatCache(1),),
    (FlatCache({"a": [1, 2]}),),
)
|
|
|
|
|
|
# Expected `str(treedef)` for the leading entries of TREES, in the same order.
TREE_STRINGS = (
    "PyTreeDef(None)",
    "PyTreeDef((None,))",
    "PyTreeDef(())",
    "PyTreeDef([()])",
    "PyTreeDef((*, *))",
    "PyTreeDef(((*, *), [*, (*, None, *)]))",
    "PyTreeDef([*])",
    ("PyTreeDef([*, CustomNode(namedtuple[<class '__main__.ATuple'>], [(*, "
     "CustomNode(namedtuple[<class '__main__.ATuple'>], [*, None])), {'baz': "
     "*}])])"),
    "PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])",
    "PyTreeDef(CustomNode(<class '__main__.Special'>[None], [*, *]))",
    "PyTreeDef({'a': *, 'b': *})",
)

# pytest expects "tree_util_test.ATuple"
# Turn each expected string into a regex that matches literally, except that
# the module name "__main__" is replaced by a wildcard.
STRS = [re.escape(raw).replace("__main__", ".*") for raw in TREE_STRINGS]
TREE_STRINGS = STRS
|
|
|
|
# Leaf fixtures, each wrapped in a 1-tuple for use with
# `parameterized.parameters`.
LEAVES = tuple((leaf,) for leaf in ("foo", 0.1, 1, object()))
|
|
|
|
|
|
class TreeTest(jtu.JaxTestCase):
  """Tests for flattening, unflattening and structural queries on pytrees."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    # Flatten followed by unflatten must reproduce the input.
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    # Same as testRoundtrip, but the leaves come from treedef.flatten_up_to.
    treedef = tree_util.tree_flatten(inputs)[1]
    leaves = treedef.flatten_up_to(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(rebuilt.func, inputs.func)
    self.assertEqual(rebuilt.args, inputs.args)
    self.assertEqual(rebuilt.keywords, inputs.keywords)

  def testPartialDoesNotMergeWithOtherPartials(self):
    def f(a, b, c):
      pass

    inner = functools.partial(f, 2)
    outer = tree_util.Partial(inner, 3)
    self.assertEqual(outer.args, (3,))

  def testPartialFuncAttributeHasStableHash(self):
    # https://github.com/google/jax/issues/9429
    fun = functools.partial(print, 1)
    first = tree_util.Partial(fun, 2)
    second = tree_util.Partial(fun, 2)
    self.assertEqual(fun, first.func)
    self.assertEqual(first.func, fun)
    self.assertEqual(first.func, second.func)
    self.assertEqual(hash(first.func), hash(second.func))

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    leaves, treedef = _process_pytree(tuple, inputs)
    rebuilt = tree_util.build_tree(treedef, leaves)
    self.assertEqual(rebuilt, inputs)

  def testChildren(self):
    parent = tree_util.tree_flatten(((1, 2, 3), (4,)))[1]
    child0 = tree_util.tree_flatten((0, 0, 0))[1]
    child1 = tree_util.tree_flatten((7,))[1]
    self.assertEqual([child0, child1], parent.children())

  def testTreedefTupleFromChildren(self):
    # https://github.com/google/jax/issues/7377
    tree = ((1, 2, (3, 4)), (5,))
    leaves, original = tree_util.tree_flatten(tree)
    reassembled = tree_util.treedef_tuple(original.children())
    self.assertEqual(original.num_leaves, len(leaves))
    self.assertEqual(original.num_leaves, reassembled.num_leaves)
    self.assertEqual(original.num_nodes, reassembled.num_nodes)

  def testTreedefTupleComparesEqual(self):
    # https://github.com/google/jax/issues/9066
    self.assertEqual(
        tree_util.tree_structure((3,)),
        tree_util.treedef_tuple((tree_util.tree_structure(3),)))

  def testFlattenUpTo(self):
    treedef = tree_util.tree_flatten(
        [(1, 2), None, ATuple(foo=3, bar=7)])[1]
    deeper = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
    out = treedef.flatten_up_to(deeper)
    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    x = ((1, 2), (3, 4, 5))
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
    expected = (
        ((1, [3]), (2, None)),
        ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])),
    )
    self.assertEqual(out, expected)

  def testTreeMultimapWithIsLeafArgument(self):
    x = ((1, 2), [3, 4, 5])
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(
        lambda *xs: tuple(xs), x, y,
        is_leaf=lambda n: isinstance(n, list))
    expected = (
        ((1, [3]), (2, None)),
        ([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])),
    )
    self.assertEqual(out, expected)

  @parameterized.parameters(
      tree_util.tree_leaves,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
  def testFlattenIsLeaf(self, leaf_fn):
    x = [(1, 2), (3, 4), (5, 6)]
    # (is_leaf predicate, expected leaves) pairs, checked in order.
    cases = (
        (lambda t: False, [1, 2, 3, 4, 5, 6]),
        (lambda t: isinstance(t, tuple), x),
        (lambda t: isinstance(t, list), [x]),
        (lambda t: True, [x]),
    )
    for predicate, expected in cases:
      self.assertEqual(leaf_fn(x, is_leaf=predicate), expected)

    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(leaves, [(1,), (2,), (3,)])

  @parameterized.parameters(
      tree_util.tree_structure,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
  def testStructureIsLeaf(self, structure_fn):
    x = [(1, 2), (3, 4), (5, 6)]
    # (is_leaf predicate, expected leaf count) pairs, checked in order.
    cases = (
        (lambda t: False, 6),
        (lambda t: isinstance(t, tuple), 3),
        (lambda t: isinstance(t, list), 1),
        (lambda t: True, 1),
    )
    for predicate, expected in cases:
      self.assertEqual(structure_fn(x, is_leaf=predicate).num_leaves, expected)

    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(treedef.num_leaves, 3)

  @parameterized.parameters(*TREES)
  def testRoundtripIsLeaf(self, tree):
    leaves, treedef = tree_util.tree_flatten(
        tree, is_leaf=lambda t: isinstance(t, tuple))
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, tree)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    flat = tree_util.tree_leaves(tree)
    self.assertTrue(tree_util.all_leaves(flat))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    outer = tree_util.tree_structure(tree)
    inner = tree_util.tree_structure(["*", "*", "*"])
    composed_treedef = outer.compose(inner)

    # Every outer leaf is replaced by one copy of the inner tree.
    expected_leaves = outer.num_leaves * inner.num_leaves
    self.assertEqual(composed_treedef.num_leaves, expected_leaves)
    interior_nodes = outer.num_nodes - outer.num_leaves
    expected_nodes = interior_nodes + inner.num_nodes * outer.num_leaves
    self.assertEqual(composed_treedef.num_nodes, expected_nodes)

    leaves = [1] * expected_leaves
    composed = tree_util.tree_unflatten(composed_treedef, leaves)
    self.assertEqual(leaves, tree_util.tree_leaves(composed))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    outer_treedef = tree_util.tree_structure(tree)
    if not outer_treedef.num_leaves:
      self.skipTest("Skipping empty tree")
    inner_treedef = tree_util.tree_structure([1, 1, 1])
    nested = tree_util.tree_map(lambda x: [x, x, x], tree)
    transposed = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
    self.assertEqual(transposed, [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner_treedef = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeMismatchInner(self):
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
    inner_treedef = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeWithCustomObject(self):
    outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner_treedef = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    actual = tree_util.tree_transpose(
        outer_treedef, inner_treedef, FlatCache({"a": [3, 4], "b": [5, 6]}))
    self.assertEqual(expected, actual)

  @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
  def testStringRepresentation(self, tree, correct_string):
    """Checks that the string representation of a tree works."""
    rendered = str(tree_util.tree_structure(tree))
    self.assertRegex(rendered, correct_string)

  def testTreeDefWithEmptyDictStringRepresentation(self):
    rendered = str(tree_util.tree_structure({}))
    self.assertEqual(rendered, "PyTreeDef({})")
|
|
|
|
|
|
class RavelUtilTest(jtu.JaxTestCase):
  """Tests for flatten_util.ravel_pytree and the unravel function it returns."""

  def _check_roundtrip(self, tree, expected_dtype):
    # Ravel `tree`, check the flat dtype, then check unravel restores it exactly.
    flat, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(flat.dtype, expected_dtype)
    restored = unravel(flat)
    self.assertAllClose(tree, restored, atol=0., rtol=0.)

  def testFloats(self):
    tree = [jnp.array([3.], jnp.float32),
            jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    self._check_roundtrip(tree, jnp.float32)

  def testInts(self):
    tree = [jnp.array([3], jnp.int32),
            jnp.array([[1, 2], [3, 4]], jnp.int32)]
    self._check_roundtrip(tree, jnp.int32)

  def testMixedFloatInt(self):
    tree = [jnp.array([3], jnp.int32),
            jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    self._check_roundtrip(tree, jnp.promote_types(jnp.float32, jnp.int32))

  def testMixedIntBool(self):
    tree = [jnp.array([0], jnp.bool_),
            jnp.array([[1, 2], [3, 4]], jnp.int32)]
    self._check_roundtrip(tree, jnp.promote_types(jnp.bool_, jnp.int32))

  def testMixedFloatComplex(self):
    tree = [jnp.array([1.], jnp.float32),
            jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
    self._check_roundtrip(tree, jnp.promote_types(jnp.float32, jnp.complex64))

  def testEmpty(self):
    # float32 is the expected dtype for an empty tree by convention.
    self._check_roundtrip([], jnp.float32)

  def testDtypePolymorphicUnravel(self):
    # https://github.com/google/jax/issues/7809
    x = jnp.arange(10, dtype=jnp.float32)
    x_flat, unravel = flatten_util.ravel_pytree(x)
    mask = x_flat < 5.3
    restored = unravel(mask)
    self.assertEqual(restored.dtype, mask.dtype)

  def testDtypeMonomorphicUnravel(self):
    # https://github.com/google/jax/issues/7809
    floats = jnp.arange(10, dtype=jnp.float32)
    ints = jnp.arange(10, dtype=jnp.int32)
    x_flat, unravel = flatten_util.ravel_pytree((floats, ints))
    mask = x_flat < 5.3
    with self.assertRaisesRegex(TypeError, 'but expected dtype'):
      _ = unravel(mask)
|
|
|
|
|
|
class TreePrefixErrorsTest(jtu.JaxTestCase):
  """Tests for the error factories yielded by prefix_errors."""

  def _assert_prefix_error(self, error, reason, location):
    # `error` is called with a name and the result is raised; the ValueError
    # message must match the assembled regex. `location` is a regex fragment.
    # NOTE(review): the single space after the newline is reproduced from the
    # source as seen; if extraction collapsed a wider indentation inside the
    # original literals, restore it here -- confirm against prefix_errors.
    expected = f"pytree structure error: {reason} at key path\n {location}"
    with self.assertRaisesRegex(ValueError, expected):
      raise error('in_axes')

  def test_different_types(self):
    (error,) = prefix_errors((1, 2), [1, 2])
    self._assert_prefix_error(error, "different types", "in_axes tree root")

  def test_different_types_nested(self):
    (error,) = prefix_errors(((1,), (2,)), ([3], (4,)))
    self._assert_prefix_error(error, "different types", r"in_axes\[0\]")

  def test_different_types_multiple(self):
    first, second = prefix_errors(((1,), (2,)), ([3], [4]))
    self._assert_prefix_error(first, "different types", r"in_axes\[0\]")
    self._assert_prefix_error(second, "different types", r"in_axes\[1\]")

  def test_different_num_children(self):
    (error,) = prefix_errors((1,), (2, 3))
    self._assert_prefix_error(
        error, "different numbers of pytree children", "in_axes tree root")

  def test_different_num_children_nested(self):
    (error,) = prefix_errors([[1]], [[2, 3]])
    self._assert_prefix_error(
        error, "different numbers of pytree children", r"in_axes\[0\]")

  def test_different_num_children_multiple(self):
    first, second = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
    self._assert_prefix_error(
        first, "different numbers of pytree children", r"in_axes\[0\]")
    self._assert_prefix_error(
        second, "different numbers of pytree children", r"in_axes\[1\]")

  def test_different_metadata(self):
    (error,) = prefix_errors({1: 2}, {3: 4})
    self._assert_prefix_error(
        error, "different pytree metadata", "in_axes tree root")

  def test_different_metadata_nested(self):
    (error,) = prefix_errors([{1: 2}], [{3: 4}])
    self._assert_prefix_error(
        error, "different pytree metadata", r"in_axes\[0\]")

  def test_different_metadata_multiple(self):
    first, second = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
    self._assert_prefix_error(
        first, "different pytree metadata", r"in_axes\[0\]")
    self._assert_prefix_error(
        second, "different pytree metadata", r"in_axes\[1\]")

  def test_fallback_keypath(self):
    (error,) = prefix_errors(Special(1, [2]), Special(3, 4))
    self._assert_prefix_error(
        error, "different types", r"in_axes\[<flat index 1>\]")

  def test_no_errors(self):
    # Unpacking into an empty target fails unless no errors are produced.
    () = prefix_errors((1, 2), ((11, 12, 13), 2))
|
|
|
|
|
|
# Run the tests with the JAX test loader when executed as a script.
if __name__ == "__main__":
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|