mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
separate register_pytree_node and register_pytree_with_keys tests
This commit is contained in:
parent
d05cf13e94
commit
da3799959a
@ -57,10 +57,15 @@ class AnObject:
|
||||
def __repr__(self):
|
||||
return f"AnObject({self.x},{self.y},{self.z})"
|
||||
|
||||
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
|
||||
lambda z, xy: AnObject(xy[0], xy[1], z))
|
||||
|
||||
class AnObject2(AnObject): pass
|
||||
|
||||
tree_util.register_pytree_with_keys(
|
||||
AnObject,
|
||||
AnObject2,
|
||||
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
|
||||
lambda z, xy: AnObject(xy[0], xy[1], z), # unflatten (no key involved)
|
||||
lambda z, xy: AnObject2(xy[0], xy[1], z), # unflatten (no key involved)
|
||||
)
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@ -133,6 +138,7 @@ TREES = (
|
||||
([3],),
|
||||
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
|
||||
([AnObject(3, None, [4, "foo"])],),
|
||||
([AnObject2(3, None, [4, "foo"])],),
|
||||
(Special(2, 3.),),
|
||||
({"a": 1, "b": 2},),
|
||||
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
@ -156,6 +162,7 @@ TREE_STRINGS = (
|
||||
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
|
||||
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
|
||||
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
|
||||
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
|
||||
"PyTreeDef(CustomNode(Special[None], [*, *]))",
|
||||
"PyTreeDef({'a': *, 'b': *})",
|
||||
)
|
||||
@ -185,7 +192,7 @@ TREES_WITH_KEYPATH = (
|
||||
(((1, "foo"), ["bar", (3, None, 7)]),),
|
||||
([3],),
|
||||
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
|
||||
([AnObject(3, None, [4, "foo"])],),
|
||||
([AnObject2(3, None, [4, "foo"])],),
|
||||
(SpecialWithKeys(2, 3.),),
|
||||
({"a": 1, "b": 0},),
|
||||
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
@ -454,13 +461,13 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
|
||||
|
||||
def testTreeMapWithPathMultipleTrees(self):
|
||||
tree1 = [AnObject(x=12,
|
||||
y={'cin': [1, 4, 10], 'bar': None},
|
||||
z='constantdef'),
|
||||
tree1 = [AnObject2(x=12,
|
||||
y={'cin': [1, 4, 10], 'bar': None},
|
||||
z='constantdef'),
|
||||
5]
|
||||
tree2 = [AnObject(x=2,
|
||||
y={'cin': [2, 2, 2], 'bar': None},
|
||||
z='constantdef'),
|
||||
tree2 = [AnObject2(x=2,
|
||||
y={'cin': [2, 2, 2], 'bar': None},
|
||||
z='constantdef'),
|
||||
2]
|
||||
from_two_trees = tree_util.tree_map_with_path(
|
||||
lambda kp, a, b: a + b, tree1, tree2
|
||||
@ -504,7 +511,7 @@ class TreeTest(jtu.JaxTestCase):
|
||||
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
||||
tree1 = {'a': 1,
|
||||
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
||||
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
|
||||
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
|
||||
self.assertEqual(
|
||||
@ -523,7 +530,7 @@ class TreeTest(jtu.JaxTestCase):
|
||||
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
||||
tree1 = {'a': 1,
|
||||
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
||||
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
self.assertEqual(flatten_one_level(tree1["sub"])[0],
|
||||
tree1["sub"])
|
||||
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],
|
||||
|
Loading…
x
Reference in New Issue
Block a user