mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Drop the private _process_pytree method from tree_util.
Removing this tested but otherwise unused method makes it easier to merge https://github.com/tensorflow/tensorflow/pull/56202 which changes the API contract of (undocumented) method .walk(). Technically speaking changing the contract of .walk() breaks backward compatibility, but we've never advertised its existence and as far as I can tell it has no users in the code I have access to. PiperOrigin-RevId: 455687311
This commit is contained in:
parent
cc0f51603d
commit
7782feeb6c
@ -201,11 +201,6 @@ def tree_multimap(*args, **kwargs):
|
||||
'instead as a drop-in replacement.', FutureWarning)
|
||||
return tree_map(*args, **kwargs)
|
||||
|
||||
# TODO(mattjj,phawkins): consider removing this function
|
||||
def _process_pytree(process_node, tree):
|
||||
leaves, treedef = pytree.flatten(tree)
|
||||
return treedef.walk(process_node, None, leaves), treedef
|
||||
|
||||
def build_tree(treedef, xs):
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
|
@ -24,7 +24,7 @@ from jax import tree_util
|
||||
from jax import flatten_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import pytree as pytree
|
||||
from jax._src.tree_util import _process_pytree, prefix_errors
|
||||
from jax._src.tree_util import prefix_errors
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@ -210,12 +210,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(p1.func, p2.func)
|
||||
self.assertEqual(hash(p1.func), hash(p2.func))
|
||||
|
||||
@parameterized.parameters(*(TREES + LEAVES))
|
||||
def testRoundtripViaBuild(self, inputs):
|
||||
xs, tree = _process_pytree(tuple, inputs)
|
||||
actual = tree_util.build_tree(tree, xs)
|
||||
self.assertEqual(actual, inputs)
|
||||
|
||||
def testChildren(self):
|
||||
_, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
|
||||
_, c0 = tree_util.tree_flatten((0, 0, 0))
|
||||
|
Loading…
x
Reference in New Issue
Block a user