mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #23684 from simonster:sjk/fix-prefix-error
PiperOrigin-RevId: 686133952
This commit is contained in:
commit
e461c0496f
@ -1227,11 +1227,11 @@ def _prefix_error(
|
||||
if type(prefix_tree) != type(full_tree):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different types at key path\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {name}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {name} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of different type\n"
|
||||
f" {type(full_tree)}.".format(name=name))
|
||||
f" {type(full_tree)}.")
|
||||
return # don't look for more errors in this subtree
|
||||
|
||||
# Or they may disagree if their roots have different numbers or keys of
|
||||
@ -1251,11 +1251,10 @@ def _prefix_error(
|
||||
ty = type(prefix_tree)
|
||||
yield lambda name: ValueError(
|
||||
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type "
|
||||
f" {name}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {name} has a subtree of type "
|
||||
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
|
||||
f"has a subtree of the same type but of length {len(full_tree)}."
|
||||
.format(name=name))
|
||||
f"has a subtree of the same type but of length {len(full_tree)}.")
|
||||
return # don't look for more errors in this subtree
|
||||
else:
|
||||
# Next we handle the general case of checking child keys.
|
||||
@ -1266,15 +1265,14 @@ def _prefix_error(
|
||||
if len(prefix_tree_children) != len(full_tree_children):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different numbers of pytree children at key path\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {name}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {name} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with {len(prefix_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of the same "
|
||||
f"type but with {len(full_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
|
||||
.format(name=name)
|
||||
+ ("" if diff is None else
|
||||
f"so the symmetric difference on key sets is\n"
|
||||
f" {' '.join(str(k.key) for k in diff)}"))
|
||||
@ -1291,8 +1289,8 @@ def _prefix_error(
|
||||
prefix=" ")
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different pytree metadata at key path\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {name}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {name} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with metadata\n"
|
||||
f" {prefix_tree_meta_str}\n"
|
||||
@ -1300,7 +1298,7 @@ def _prefix_error(
|
||||
f"type but with metadata\n"
|
||||
f" {full_tree_meta_str}\n"
|
||||
f"so the diff in the metadata at these pytree nodes is\n"
|
||||
f"{metadata_diff}".format(name=name))
|
||||
f"{metadata_diff}")
|
||||
return # don't look for more errors in this subtree
|
||||
|
||||
# If the root types and numbers of children agree, there must be an error
|
||||
|
@ -1146,6 +1146,37 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_curly_braces_in_keys_no_children(self):
|
||||
e, = prefix_errors({"{oops}": {}}, {})
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
"at key path\n"
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_curly_braces_in_keys_list_length(self):
|
||||
e, = prefix_errors({"{oops}": []}, {"{oops}": [{}]})
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
r" in_axes\['{oops}'\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_curly_braces_in_keys_different_lengths(self):
|
||||
e, = prefix_errors({"{oops}": {}}, {"{oops}": 1})
|
||||
expected = ("pytree structure error: different types at key path\n"
|
||||
r" in_axes\['{oops}'\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_curly_braces_in_keys_different_metadata(self):
|
||||
e, = prefix_errors({"{oops}": {"{a}": 1}}, {"{oops}": {"{b}": 1}})
|
||||
expected = ("pytree structure error: different pytree metadata "
|
||||
"at key path\n"
|
||||
r" in_axes\['{oops}'\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
|
||||
class TreeAliasTest(jtu.JaxTestCase):
|
||||
"""Simple smoke-tests for tree_util aliases under jax.tree"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user