Store sorted flattened dict keys in PyTree as a C++ vector instead of a py::list to avoid creating a new Python object on every single dict flatten. For deeply nested dicts, this avoids excessive GC pressure and the slowdown that occurs whenever the GC needs to sweep too many live Python objects.

PiperOrigin-RevId: 502967020
This commit is contained in:
Qiao Zhang 2023-01-18 13:39:58 -08:00 committed by jax authors
parent 4add3b8cee
commit d58266eac7

View File

@ -371,12 +371,49 @@ class TreeTest(jtu.JaxTestCase):
treedef_restored = pickle.loads(pickle.dumps(treedef))
self.assertEqual(treedef, treedef_restored)
def testDictKeysSortable(self):
  """Flattening a dict whose keys cannot be ordered must raise TypeError."""
  # One str key and one int key: comparing them with '<' is unsupported.
  mixed_key_dict = {"a": 1, 2: "b"}
  with self.assertRaisesRegex(TypeError, "'<' not supported"):
    _unused_leaves, _unused_treedef = tree_util.tree_flatten(mixed_key_dict)
def testFlattenDictKeyOrder(self):
  """Checks that dict flattening orders leaves and keys by sorted key.

  The input dict is deliberately written out of order, at both nesting
  levels, so the assertions only pass if keys are sorted during flattening
  and the round-tripped dict is rebuilt in that sorted order.
  """
  unordered = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  flat_leaves, structure = tree_util.tree_flatten(unordered)

  # Leaves follow key order 'a', 'b', then the nested dict under 'c'.
  self.assertEqual(flat_leaves, [1, 2, 1, 2])

  expected_repr = "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})"
  self.assertEqual(str(structure), expected_repr)

  rebuilt = tree_util.tree_unflatten(structure, flat_leaves)
  self.assertEqual(list(rebuilt.keys()), ["a", "b", "c"])
def testWalk(self):
  """Checks what PyTreeDef.walk passes to its node and leaf callbacks.

  Both callbacks return None, so each visited node is expected to be a tuple
  of None with one entry per child, and the node data for a dict node is
  expected to be its list of sorted keys.
  """
  tree = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  flat_leaves, structure = tree_util.tree_flatten(tree)

  seen_nodes, seen_node_data, seen_leaves = [], [], []

  def record_node(node, node_data):
    seen_nodes.append(node)
    seen_node_data.append(node_data)

  def record_leaf(leaf):
    seen_leaves.append(leaf)

  structure.walk(record_node, record_leaf, flat_leaves)

  self.assertEqual(seen_leaves, [1, 2, 1, 2])
  # The inner dict (two children) is reported before the outer dict (three).
  self.assertEqual(seen_nodes, [(None, None), (None, None, None)])
  self.assertEqual(seen_node_data, [["a", "b"], ["a", "b", "c"]])
class RavelUtilTest(jtu.JaxTestCase):
def testFloats(self):
tree = [jnp.array([3.], jnp.float32),
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
tree = [
jnp.array([3.0], jnp.float32),
jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32),
]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.float32)
tree_ = unravel(raveled)