Reduce the verbosity of treedef printing for custom nodes.

For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.

Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:

    >>> params = { .. some params .. }
    >>> f = jax.jit(..).lower(params).compile()
    >>> f(params)  # fine
    >>> params['some_new_thing'] = something
    >>> f(params)
    TypeError: function compiled for {treedef}, called with {treedef}.

PiperOrigin-RevId: 461190971
This commit is contained in:
Tom Hennigan 2022-07-15 07:14:00 -07:00 committed by jax authors
parent 023e6f5955
commit 10720258ea

View File

@ -15,6 +15,7 @@
import collections
import functools
import re
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -140,11 +141,10 @@ TREE_STRINGS = (
"PyTreeDef((*, *))",
"PyTreeDef(((*, *), [*, (*, None, *)]))",
"PyTreeDef([*])",
("PyTreeDef([*, CustomNode(namedtuple[<class '__main__.ATuple'>], [(*, "
"CustomNode(namedtuple[<class '__main__.ATuple'>], [*, None])), {'baz': "
"*}])])"),
"PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(<class '__main__.Special'>[None], [*, *]))",
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(Special[None], [*, *]))",
"PyTreeDef({'a': *, 'b': *})",
)
@ -352,6 +352,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(expected, actual)
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
@unittest.skipIf(pytree_version < 2, "Requires jaxlib 0.3.15")
def testStringRepresentation(self, tree, correct_string):
"""Checks that the string representation of a tree works."""
treedef = tree_util.tree_structure(tree)