separate register_pytree_node and register_pytree_with_keys tests

This commit is contained in:
Matthew Johnson 2023-03-18 14:32:19 -07:00
parent d05cf13e94
commit da3799959a

View File

@ -57,10 +57,15 @@ class AnObject:
def __repr__(self):
return f"AnObject({self.x},{self.y},{self.z})"
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
lambda z, xy: AnObject(xy[0], xy[1], z))
class AnObject2(AnObject): pass
tree_util.register_pytree_with_keys(
AnObject,
AnObject2,
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
lambda z, xy: AnObject(xy[0], xy[1], z), # unflatten (no key involved)
lambda z, xy: AnObject2(xy[0], xy[1], z), # unflatten (no key involved)
)
@tree_util.register_pytree_node_class
@ -133,6 +138,7 @@ TREES = (
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject(3, None, [4, "foo"])],),
([AnObject2(3, None, [4, "foo"])],),
(Special(2, 3.),),
({"a": 1, "b": 2},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
@ -156,6 +162,7 @@ TREE_STRINGS = (
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(Special[None], [*, *]))",
"PyTreeDef({'a': *, 'b': *})",
)
@ -185,7 +192,7 @@ TREES_WITH_KEYPATH = (
(((1, "foo"), ["bar", (3, None, 7)]),),
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject(3, None, [4, "foo"])],),
([AnObject2(3, None, [4, "foo"])],),
(SpecialWithKeys(2, 3.),),
({"a": 1, "b": 0},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
@ -454,13 +461,13 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
def testTreeMapWithPathMultipleTrees(self):
tree1 = [AnObject(x=12,
y={'cin': [1, 4, 10], 'bar': None},
z='constantdef'),
tree1 = [AnObject2(x=12,
y={'cin': [1, 4, 10], 'bar': None},
z='constantdef'),
5]
tree2 = [AnObject(x=2,
y={'cin': [2, 2, 2], 'bar': None},
z='constantdef'),
tree2 = [AnObject2(x=2,
y={'cin': [2, 2, 2], 'bar': None},
z='constantdef'),
2]
from_two_trees = tree_util.tree_map_with_path(
lambda kp, a, b: a + b, tree1, tree2
@ -504,7 +511,7 @@ class TreeTest(jtu.JaxTestCase):
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
self.assertEqual(
@ -523,7 +530,7 @@ class TreeTest(jtu.JaxTestCase):
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
self.assertEqual(flatten_one_level(tree1["sub"])[0],
tree1["sub"])
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],