specialize tree prefix error message for list/tuple

This commit is contained in:
Matthew Johnson 2023-01-20 10:51:02 -08:00
parent 068423bb96
commit cea2b6b6f8
2 changed files with 60 additions and 27 deletions

View File

@ -492,33 +492,49 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
f" {type(full_tree)}.".format(name=name))
return # don't look for more errors in this subtree
# Or they may disagree if their roots have different numbers of children (note
# that because both prefix_tree and full_tree have the same type at this
# point, and because prefix_tree is not a leaf, each can be flattened once):
# Or they may disagree if their roots have different numbers or keys of
# children. Because both prefix_tree and full_tree have the same type at this
# point, and because prefix_tree is not a leaf, each can be flattened once:
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
prefix_tree_keys = _child_keys(prefix_tree)
full_tree_keys = _child_keys(full_tree)
try:
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
except:
diff = None
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
f"but at the same key path the full pytree has a subtree of the same "
f"type but with {len(full_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
.format(name=name)
+ ("" if diff is None else
f"so the symmetric difference on key sets is\n"
f" {' '.join(str(k.key) for k in diff)}"))
return # don't look for more errors in this subtree
# First we check special case types (list and tuple, though if they were
# pytrees we could check strings and sets here, basically Sequences) so that
# we can report length disagreement rather than integer keys:
if isinstance(prefix_tree, (list, tuple)):
if len(prefix_tree) != len(full_tree):
ty = type(prefix_tree)
yield lambda name: ValueError(
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type "
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
f"has a subtree of the same type but of length {len(full_tree)}."
.format(name=name))
return # don't look for more errors in this subtree
else:
# Next we handle the general case of checking child keys.
try:
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
except:
diff = None
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
f"but at the same key path the full pytree has a subtree of the same "
f"type but with {len(full_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
.format(name=name)
+ ("" if diff is None else
f"so the symmetric difference on key sets is\n"
f" {' '.join(str(k.key) for k in diff)}"))
return # don't look for more errors in this subtree
# Or they may disagree if their roots have different pytree metadata:
if prefix_tree_meta != full_tree_meta:

View File

@ -507,8 +507,25 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')
def test_different_num_children(self):
def test_different_num_children_tuple(self):
e, = prefix_errors((1,), (2, 3))
expected = ("pytree structure error: different lengths of tuple "
"at key path\n"
" in_axes tree root")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_list(self):
e, = prefix_errors([1], [2, 3])
expected = ("pytree structure error: different lengths of list "
"at key path\n"
" in_axes tree root")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_different_num_children_generic(self):
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes tree root")
@ -517,7 +534,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
def test_different_num_children_nested(self):
e, = prefix_errors([[1]], [[2, 3]])
expected = ("pytree structure error: different numbers of pytree children "
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
@ -525,12 +542,12 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
def test_different_num_children_multiple(self):
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
expected = ("pytree structure error: different numbers of pytree children "
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different numbers of pytree children "
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):