mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reduce the verbosity of treedef printing for custom nodes.
For very large trees of custom nodes this printing can be very verbose with a lot or repetition. Our internal repository also encourages very deep package names which exacerbates this issue. Users encounter treedef printing when interacting with some staging APIs in JAX, for example: >>> params = { .. some params .. } >>> f = jax.jit(..).lower(params).compile() >>> f(params) # fine >>> params['some_new_thing'] = something >>> f(params) TypeError: function compiled for {treedef}, called with {treedef}. PiperOrigin-RevId: 461190971
This commit is contained in:
parent
023e6f5955
commit
10720258ea
@ -15,6 +15,7 @@
|
||||
import collections
|
||||
import functools
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -140,11 +141,10 @@ TREE_STRINGS = (
|
||||
"PyTreeDef((*, *))",
|
||||
"PyTreeDef(((*, *), [*, (*, None, *)]))",
|
||||
"PyTreeDef([*])",
|
||||
("PyTreeDef([*, CustomNode(namedtuple[<class '__main__.ATuple'>], [(*, "
|
||||
"CustomNode(namedtuple[<class '__main__.ATuple'>], [*, None])), {'baz': "
|
||||
"*}])])"),
|
||||
"PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])",
|
||||
"PyTreeDef(CustomNode(<class '__main__.Special'>[None], [*, *]))",
|
||||
("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
|
||||
"CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
|
||||
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
|
||||
"PyTreeDef(CustomNode(Special[None], [*, *]))",
|
||||
"PyTreeDef({'a': *, 'b': *})",
|
||||
)
|
||||
|
||||
@ -352,6 +352,7 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
|
||||
@unittest.skipIf(pytree_version < 2, "Requires jaxlib 0.3.15")
|
||||
def testStringRepresentation(self, tree, correct_string):
|
||||
"""Checks that the string representation of a tree works."""
|
||||
treedef = tree_util.tree_structure(tree)
|
||||
|
Loading…
x
Reference in New Issue
Block a user