Add tests to cover PyTreeDef.flatten_up_to error scenarios.

Also improve coverage of `PyTreeDef.flatten_up_to` success scenarios.

PiperOrigin-RevId: 570152827
This commit is contained in:
jax authors 2023-10-02 13:00:19 -07:00
parent c60b4aef06
commit d45fa22424

View File

@ -318,12 +318,114 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(flat2, list(range(10)))
self.assertEqual(flat3, list(range(10)))
def testFlattenUpTo(self):
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
out = tree.flatten_up_to([({
"foo": 7
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
@parameterized.parameters(
(
[(1, 2), None, ATuple(foo=3, bar=7)],
[({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)],
[{"foo": 7}, (3, 4), (11, 9), None],
),
([1], [{"a": 7}], [{"a": 7}]),
([1], [[7]], [[7]]),
([1], [(7,)], [(7,)]),
((1, 2), ({"a": 7}, {"a": 8}), [{"a": 7}, {"a": 8}]),
((1,), ([7],), [[7]]),
((1,), ((7,),), [(7,)]),
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
({"a": 1}, {"a": (7,)}, [(7,)]),
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
)
def testFlattenUpTo(self, tree, xs, expected):
_, tree_def = tree_util.tree_flatten(tree)
out = tree_def.flatten_up_to(xs)
self.assertEqual(out, expected)
@parameterized.parameters(
([1, 2], [7], re.escape("List arity mismatch: 1 != 2; list: [7].")),
((1,), (7, 8), re.escape("Tuple arity mismatch: 2 != 1; tuple: (7, 8).")),
(
{"a": 1},
{"a": 7, "b": 8},
re.escape(
"Dict key mismatch; expected keys: ['a']; dict: {'a': 7, 'b': 8}."
),
),
(
{"a": 1},
{"b": 7},
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
),
([1], {"a": 7}, re.escape("Expected list, got {'a': 7}.")),
([1], (7,), re.escape("Expected list, got (7,).")),
((1,), [7], re.escape("Expected tuple, got [7].")),
((1,), {"b": 7}, re.escape("Expected tuple, got {'b': 7}.")),
({"a": 1}, (7,), re.escape("Expected dict, got (7,).")),
({"a": 1}, [7], re.escape("Expected dict, got [7].")),
([[1]], [7], re.escape("Expected list, got 7.")),
([[1]], [(7,)], re.escape("Expected list, got (7,).")),
([[1]], [{"a": 7}], re.escape("Expected list, got {'a': 7}.")),
([(1,)], [7], re.escape("Expected tuple, got 7.")),
([(1,)], [[7]], re.escape("Expected tuple, got [7].")),
([(1,)], [{"a": 7}], re.escape("Expected tuple, got {'a': 7}.")),
([{"a": 1}], [7], re.escape("Expected dict, got 7.")),
([{"a": 1}], [[7]], re.escape("Expected dict, got [7].")),
([{"a": 1}], [(7,)], re.escape("Expected dict, got (7,).")),
(
[{"a": 1}],
[{"b": 7}],
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
),
(([1],), (7,), re.escape("Expected list, got 7.")),
(([1],), ((7,),), re.escape("Expected list, got (7,).")),
(([1],), ({"a": 7},), re.escape("Expected list, got {'a': 7}.")),
(((1,),), (7,), re.escape("Expected tuple, got 7.")),
(((1,),), ([7],), re.escape("Expected tuple, got [7].")),
(((1,),), ({"a": 7},), re.escape("Expected tuple, got {'a': 7}.")),
(({"a": 1},), (7,), re.escape("Expected dict, got 7.")),
(({"a": 1},), ([7],), re.escape("Expected dict, got [7].")),
(({"a": 1},), ((7,),), re.escape("Expected dict, got (7,).")),
(
({"a": 1},),
({"b": 7},),
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
),
({"a": [1]}, {"a": 7}, re.escape("Expected list, got 7.")),
({"a": [1]}, {"a": (7,)}, re.escape("Expected list, got (7,).")),
({"a": [1]}, {"a": {"a": 7}}, re.escape("Expected list, got {'a': 7}.")),
({"a": (1,)}, {"a": 7}, re.escape("Expected tuple, got 7.")),
({"a": (1,)}, {"a": [7]}, re.escape("Expected tuple, got [7].")),
(
{"a": (1,)},
{"a": {"a": 7}},
re.escape("Expected tuple, got {'a': 7}."),
),
({"a": {"a": 1}}, {"a": 7}, re.escape("Expected dict, got 7.")),
({"a": {"a": 1}}, {"a": [7]}, re.escape("Expected dict, got [7].")),
({"a": {"a": 1}}, {"a": (7,)}, re.escape("Expected dict, got (7,).")),
(
{"a": {"a": 1}},
{"a": {"b": 7}},
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
),
(
[ATuple(foo=1, bar=2)],
[(1, 2)],
re.escape("Expected named tuple, got (1, 2)."),
),
(
[ATuple(foo=1, bar=2)],
[ATuple2(foo=1, bar=2)],
re.escape("Named tuple type mismatch"),
),
(
[AnObject(x=[1], y=(2,), z={"a": [1]})],
[([1], (2,), {"a": [1]})],
re.escape("Custom node type mismatch"),
),
)
def testFlattenUpToErrors(self, tree, xs, error):
_, tree_def = tree_util.tree_flatten(tree)
with self.assertRaisesRegex(ValueError, error):
tree_def.flatten_up_to(xs)
def testTreeMap(self):
x = ((1, 2), (3, 4, 5))