mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make PyTreeDef
pickleable
PiperOrigin-RevId: 465306184
This commit is contained in:
parent
8610e61dbb
commit
e1b31f82fd
@ -14,6 +14,7 @@
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import pickle
|
||||
import re
|
||||
import unittest
|
||||
|
||||
@ -369,6 +370,13 @@ class TreeTest(jtu.JaxTestCase):
|
||||
def testTreeDefWithEmptyDictStringRepresentation(self):
|
||||
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
|
||||
|
||||
@parameterized.parameters(*TREES)
|
||||
@unittest.skipIf(pytree_version < 3, "Requires jaxlib 0.3.16")
|
||||
def testPickleRoundTrip(self, tree):
|
||||
treedef = tree_util.tree_structure(tree)
|
||||
treedef_restored = pickle.loads(pickle.dumps(treedef))
|
||||
self.assertEqual(treedef, treedef_restored)
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user