mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Store sorted flattened dict keys in PyTree as a c++ vector instead of py::list to avoid creating new python object on every single dict flatten. For deeply nested dict, this avoids excessive gc pressure and avoids the slowdown whenever gc needs to sweep too many live python objects.
PiperOrigin-RevId: 502967020
This commit is contained in:
parent
4add3b8cee
commit
d58266eac7
@ -371,12 +371,49 @@ class TreeTest(jtu.JaxTestCase):
|
||||
treedef_restored = pickle.loads(pickle.dumps(treedef))
|
||||
self.assertEqual(treedef, treedef_restored)
|
||||
|
||||
def testDictKeysSortable(self):
|
||||
d = {"a": 1, 2: "b"}
|
||||
with self.assertRaisesRegex(TypeError, "'<' not supported"):
|
||||
_, _ = tree_util.tree_flatten(d)
|
||||
|
||||
def testFlattenDictKeyOrder(self):
|
||||
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
|
||||
leaves, treedef = tree_util.tree_flatten(d)
|
||||
self.assertEqual(leaves, [1, 2, 1, 2])
|
||||
self.assertEqual(
|
||||
str(treedef), "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})"
|
||||
)
|
||||
restored_d = tree_util.tree_unflatten(treedef, leaves)
|
||||
self.assertEqual(list(restored_d.keys()), ["a", "b", "c"])
|
||||
|
||||
def testWalk(self):
|
||||
d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
|
||||
leaves, treedef = tree_util.tree_flatten(d)
|
||||
|
||||
nodes_visited = []
|
||||
node_data_visited = []
|
||||
leaves_visited = []
|
||||
|
||||
def f_node(node, node_data):
|
||||
nodes_visited.append(node)
|
||||
node_data_visited.append(node_data)
|
||||
|
||||
def f_leaf(leaf):
|
||||
leaves_visited.append(leaf)
|
||||
|
||||
treedef.walk(f_node, f_leaf, leaves)
|
||||
self.assertEqual(leaves_visited, [1, 2, 1, 2])
|
||||
self.assertEqual(nodes_visited, [(None, None), (None, None, None)])
|
||||
self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]])
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
def testFloats(self):
|
||||
tree = [jnp.array([3.], jnp.float32),
|
||||
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
|
||||
tree = [
|
||||
jnp.array([3.0], jnp.float32),
|
||||
jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32),
|
||||
]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, jnp.float32)
|
||||
tree_ = unravel(raveled)
|
||||
|
Loading…
x
Reference in New Issue
Block a user