mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add tests to cover PyTreeDef.flatten_up_to
error scenarios.
Also improve coverage of `PyTreeDef.flatten_up_to` success scenarios. PiperOrigin-RevId: 570152827
This commit is contained in:
parent
c60b4aef06
commit
d45fa22424
@ -318,12 +318,114 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(flat2, list(range(10)))
|
||||
self.assertEqual(flat3, list(range(10)))
|
||||
|
||||
def testFlattenUpTo(self):
|
||||
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
|
||||
out = tree.flatten_up_to([({
|
||||
"foo": 7
|
||||
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
|
||||
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
|
||||
@parameterized.parameters(
|
||||
(
|
||||
[(1, 2), None, ATuple(foo=3, bar=7)],
|
||||
[({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)],
|
||||
[{"foo": 7}, (3, 4), (11, 9), None],
|
||||
),
|
||||
([1], [{"a": 7}], [{"a": 7}]),
|
||||
([1], [[7]], [[7]]),
|
||||
([1], [(7,)], [(7,)]),
|
||||
((1, 2), ({"a": 7}, {"a": 8}), [{"a": 7}, {"a": 8}]),
|
||||
((1,), ([7],), [[7]]),
|
||||
((1,), ((7,),), [(7,)]),
|
||||
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
|
||||
({"a": 1}, {"a": (7,)}, [(7,)]),
|
||||
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
|
||||
)
|
||||
def testFlattenUpTo(self, tree, xs, expected):
|
||||
_, tree_def = tree_util.tree_flatten(tree)
|
||||
out = tree_def.flatten_up_to(xs)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
([1, 2], [7], re.escape("List arity mismatch: 1 != 2; list: [7].")),
|
||||
((1,), (7, 8), re.escape("Tuple arity mismatch: 2 != 1; tuple: (7, 8).")),
|
||||
(
|
||||
{"a": 1},
|
||||
{"a": 7, "b": 8},
|
||||
re.escape(
|
||||
"Dict key mismatch; expected keys: ['a']; dict: {'a': 7, 'b': 8}."
|
||||
),
|
||||
),
|
||||
(
|
||||
{"a": 1},
|
||||
{"b": 7},
|
||||
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
||||
),
|
||||
([1], {"a": 7}, re.escape("Expected list, got {'a': 7}.")),
|
||||
([1], (7,), re.escape("Expected list, got (7,).")),
|
||||
((1,), [7], re.escape("Expected tuple, got [7].")),
|
||||
((1,), {"b": 7}, re.escape("Expected tuple, got {'b': 7}.")),
|
||||
({"a": 1}, (7,), re.escape("Expected dict, got (7,).")),
|
||||
({"a": 1}, [7], re.escape("Expected dict, got [7].")),
|
||||
([[1]], [7], re.escape("Expected list, got 7.")),
|
||||
([[1]], [(7,)], re.escape("Expected list, got (7,).")),
|
||||
([[1]], [{"a": 7}], re.escape("Expected list, got {'a': 7}.")),
|
||||
([(1,)], [7], re.escape("Expected tuple, got 7.")),
|
||||
([(1,)], [[7]], re.escape("Expected tuple, got [7].")),
|
||||
([(1,)], [{"a": 7}], re.escape("Expected tuple, got {'a': 7}.")),
|
||||
([{"a": 1}], [7], re.escape("Expected dict, got 7.")),
|
||||
([{"a": 1}], [[7]], re.escape("Expected dict, got [7].")),
|
||||
([{"a": 1}], [(7,)], re.escape("Expected dict, got (7,).")),
|
||||
(
|
||||
[{"a": 1}],
|
||||
[{"b": 7}],
|
||||
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
||||
),
|
||||
(([1],), (7,), re.escape("Expected list, got 7.")),
|
||||
(([1],), ((7,),), re.escape("Expected list, got (7,).")),
|
||||
(([1],), ({"a": 7},), re.escape("Expected list, got {'a': 7}.")),
|
||||
(((1,),), (7,), re.escape("Expected tuple, got 7.")),
|
||||
(((1,),), ([7],), re.escape("Expected tuple, got [7].")),
|
||||
(((1,),), ({"a": 7},), re.escape("Expected tuple, got {'a': 7}.")),
|
||||
(({"a": 1},), (7,), re.escape("Expected dict, got 7.")),
|
||||
(({"a": 1},), ([7],), re.escape("Expected dict, got [7].")),
|
||||
(({"a": 1},), ((7,),), re.escape("Expected dict, got (7,).")),
|
||||
(
|
||||
({"a": 1},),
|
||||
({"b": 7},),
|
||||
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
||||
),
|
||||
({"a": [1]}, {"a": 7}, re.escape("Expected list, got 7.")),
|
||||
({"a": [1]}, {"a": (7,)}, re.escape("Expected list, got (7,).")),
|
||||
({"a": [1]}, {"a": {"a": 7}}, re.escape("Expected list, got {'a': 7}.")),
|
||||
({"a": (1,)}, {"a": 7}, re.escape("Expected tuple, got 7.")),
|
||||
({"a": (1,)}, {"a": [7]}, re.escape("Expected tuple, got [7].")),
|
||||
(
|
||||
{"a": (1,)},
|
||||
{"a": {"a": 7}},
|
||||
re.escape("Expected tuple, got {'a': 7}."),
|
||||
),
|
||||
({"a": {"a": 1}}, {"a": 7}, re.escape("Expected dict, got 7.")),
|
||||
({"a": {"a": 1}}, {"a": [7]}, re.escape("Expected dict, got [7].")),
|
||||
({"a": {"a": 1}}, {"a": (7,)}, re.escape("Expected dict, got (7,).")),
|
||||
(
|
||||
{"a": {"a": 1}},
|
||||
{"a": {"b": 7}},
|
||||
re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."),
|
||||
),
|
||||
(
|
||||
[ATuple(foo=1, bar=2)],
|
||||
[(1, 2)],
|
||||
re.escape("Expected named tuple, got (1, 2)."),
|
||||
),
|
||||
(
|
||||
[ATuple(foo=1, bar=2)],
|
||||
[ATuple2(foo=1, bar=2)],
|
||||
re.escape("Named tuple type mismatch"),
|
||||
),
|
||||
(
|
||||
[AnObject(x=[1], y=(2,), z={"a": [1]})],
|
||||
[([1], (2,), {"a": [1]})],
|
||||
re.escape("Custom node type mismatch"),
|
||||
),
|
||||
)
|
||||
def testFlattenUpToErrors(self, tree, xs, error):
|
||||
_, tree_def = tree_util.tree_flatten(tree)
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
tree_def.flatten_up_to(xs)
|
||||
|
||||
def testTreeMap(self):
|
||||
x = ((1, 2), (3, 4, 5))
|
||||
|
Loading…
x
Reference in New Issue
Block a user