add more info to pytree prefix key errors

fixes #12643
This commit is contained in:
Matthew Johnson 2022-10-10 17:47:18 -07:00
parent ec5b1c93d7
commit b27acedf1f
3 changed files with 26 additions and 10 deletions

View File

@ -475,15 +475,27 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
# point, and because prefix_tree is not a leaf, each can be flattened once):
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
prefix_tree_keys = _child_keys(prefix_tree)
full_tree_keys = _child_keys(full_tree)
try:
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
except:
diff = None
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} children, "
f"with {len(prefix_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
f"but at the same key path the full pytree has a subtree of the same "
f"type but with {len(full_tree_children)} children.".format(name=name))
f"type but with {len(full_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
.format(name=name)
+ ("" if diff is None else
f"so the symmetric difference on key sets is\n"
f" {' '.join(str(k.key) for k in diff)}"))
return # don't look for more errors in this subtree
# Or they may disagree if their roots have different pytree metadata:
@ -510,11 +522,10 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
# If the root types and numbers of children agree, there must be an error
# in a subtree, so recurse:
keys = _child_keys(prefix_tree)
keys_ = _child_keys(full_tree)
assert keys == keys_, \
f"equal pytree nodes gave differing keys: {keys} and {keys_}"
for k, t1, t2 in zip(keys, prefix_tree_children, full_tree_children):
assert prefix_tree_keys == full_tree_keys, \
("equal pytree nodes gave differing prefix_tree_keys: "
f"{prefix_tree_keys} and {full_tree_keys}")
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
yield from _prefix_error(key_path + k, t1, t2)

View File

@ -2603,9 +2603,7 @@ class PJitErrorTest(jtu.JaxTestCase):
" pjit out_axis_resources tree root\n"
"At that key path, the prefix pytree pjit out_axis_resources has a "
"subtree of type\n"
" <class 'list'>\n"
"with 2 children, but at the same key path the full pytree has a "
"subtree of the same type but with 3 children.")
" <class 'list'>\n")
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message

View File

@ -505,6 +505,13 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')
def test_different_num_children_print_key_diff(self):
e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3})
expected = ("so the symmetric difference on key sets is\n"
" b")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_metadata(self):
e, = prefix_errors({1: 2}, {3: 4})
expected = ("pytree structure error: different pytree metadata "