Make PyTreeDef pickleable

PiperOrigin-RevId: 465306184
This commit is contained in:
jax authors 2022-08-04 07:13:01 -07:00
parent 8610e61dbb
commit e1b31f82fd

View File

@ -14,6 +14,7 @@
import collections
import functools
import pickle
import re
import unittest
@ -369,6 +370,13 @@ class TreeTest(jtu.JaxTestCase):
def testTreeDefWithEmptyDictStringRepresentation(self):
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
@parameterized.parameters(*TREES)
@unittest.skipIf(pytree_version < 3, "Requires jaxlib 0.3.16")
def testPickleRoundTrip(self, tree):
treedef = tree_util.tree_structure(tree)
treedef_restored = pickle.loads(pickle.dumps(treedef))
self.assertEqual(treedef, treedef_restored)
class RavelUtilTest(jtu.JaxTestCase):