Merge pull request #23684 from simonster:sjk/fix-prefix-error

PiperOrigin-RevId: 686133952
This commit is contained in:
jax authors 2024-10-15 09:32:30 -07:00
commit e461c0496f
2 changed files with 42 additions and 13 deletions

View File

@ -1227,11 +1227,11 @@ def _prefix_error(
if type(prefix_tree) != type(full_tree):
yield lambda name: ValueError(
"pytree structure error: different types at key path\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {name}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {name} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"but at the same key path the full pytree has a subtree of different type\n"
f" {type(full_tree)}.".format(name=name))
f" {type(full_tree)}.")
return # don't look for more errors in this subtree
# Or they may disagree if their roots have different numbers or keys of
@ -1251,11 +1251,10 @@ def _prefix_error(
ty = type(prefix_tree)
yield lambda name: ValueError(
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type "
f" {name}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {name} has a subtree of type "
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
f"has a subtree of the same type but of length {len(full_tree)}."
.format(name=name))
f"has a subtree of the same type but of length {len(full_tree)}.")
return # don't look for more errors in this subtree
else:
# Next we handle the general case of checking child keys.
@ -1266,15 +1265,14 @@ def _prefix_error(
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {name}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {name} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
f"but at the same key path the full pytree has a subtree of the same "
f"type but with {len(full_tree_children)} child keys\n"
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
.format(name=name)
+ ("" if diff is None else
f"so the symmetric difference on key sets is\n"
f" {' '.join(str(k.key) for k in diff)}"))
@ -1291,8 +1289,8 @@ def _prefix_error(
prefix=" ")
yield lambda name: ValueError(
"pytree structure error: different pytree metadata at key path\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {name}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {name} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with metadata\n"
f" {prefix_tree_meta_str}\n"
@ -1300,7 +1298,7 @@ def _prefix_error(
f"type but with metadata\n"
f" {full_tree_meta_str}\n"
f"so the diff in the metadata at these pytree nodes is\n"
f"{metadata_diff}".format(name=name))
f"{metadata_diff}")
return # don't look for more errors in this subtree
# If the root types and numbers of children agree, there must be an error

View File

@ -1146,6 +1146,37 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_curly_braces_in_keys_no_children(self):
e, = prefix_errors({"{oops}": {}}, {})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_curly_braces_in_keys_list_length(self):
e, = prefix_errors({"{oops}": []}, {"{oops}": [{}]})
expected = ("pytree structure error: different lengths of list "
"at key path\n"
r" in_axes\['{oops}'\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_curly_braces_in_keys_different_lengths(self):
e, = prefix_errors({"{oops}": {}}, {"{oops}": 1})
expected = ("pytree structure error: different types at key path\n"
r" in_axes\['{oops}'\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
def test_curly_braces_in_keys_different_metadata(self):
e, = prefix_errors({"{oops}": {"{a}": 1}}, {"{oops}": {"{b}": 1}})
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
r" in_axes\['{oops}'\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
class TreeAliasTest(jtu.JaxTestCase):
"""Simple smoke-tests for tree_util aliases under jax.tree"""