[JAX] Drop the private _process_pytree method from tree_util.

Removing this tested but otherwise unused method makes it easier to merge https://github.com/tensorflow/tensorflow/pull/56202 which changes the API contract of (undocumented) method .walk().

Technically speaking changing the contract of .walk() breaks backward compatibility, but we've never advertised its existence and as far as I can tell it has no users in the code I have access to.

PiperOrigin-RevId: 455687311
This commit is contained in:
Peter Hawkins 2022-06-17 13:45:04 -07:00 committed by jax authors
parent cc0f51603d
commit 7782feeb6c
2 changed files with 1 additions and 12 deletions

View File

@ -201,11 +201,6 @@ def tree_multimap(*args, **kwargs):
'instead as a drop-in replacement.', FutureWarning)
return tree_map(*args, **kwargs)
# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):
leaves, treedef = pytree.flatten(tree)
return treedef.walk(process_node, None, leaves), treedef
def build_tree(treedef, xs):
return treedef.from_iterable_tree(xs)

View File

@ -24,7 +24,7 @@ from jax import tree_util
from jax import flatten_util
from jax._src import test_util as jtu
from jax._src.lib import pytree as pytree
from jax._src.tree_util import _process_pytree, prefix_errors
from jax._src.tree_util import prefix_errors
import jax.numpy as jnp
@ -210,12 +210,6 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(p1.func, p2.func)
self.assertEqual(hash(p1.func), hash(p2.func))
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtripViaBuild(self, inputs):
xs, tree = _process_pytree(tuple, inputs)
actual = tree_util.build_tree(tree, xs)
self.assertEqual(actual, inputs)
def testChildren(self):
_, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
_, c0 = tree_util.tree_flatten((0, 0, 0))