mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
specialize tree prefix error message for list/tuple
This commit is contained in:
parent
068423bb96
commit
cea2b6b6f8
@ -492,33 +492,49 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
f" {type(full_tree)}.".format(name=name))
|
||||
return # don't look for more errors in this subtree
|
||||
|
||||
# Or they may disagree if their roots have different numbers of children (note
|
||||
# that because both prefix_tree and full_tree have the same type at this
|
||||
# point, and because prefix_tree is not a leaf, each can be flattened once):
|
||||
# Or they may disagree if their roots have different numbers or keys of
|
||||
# children. Because both prefix_tree and full_tree have the same type at this
|
||||
# point, and because prefix_tree is not a leaf, each can be flattened once:
|
||||
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
|
||||
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
|
||||
prefix_tree_keys = _child_keys(prefix_tree)
|
||||
full_tree_keys = _child_keys(full_tree)
|
||||
try:
|
||||
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
|
||||
except:
|
||||
diff = None
|
||||
if len(prefix_tree_children) != len(full_tree_children):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different numbers of pytree children at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with {len(prefix_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of the same "
|
||||
f"type but with {len(full_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
|
||||
.format(name=name)
|
||||
+ ("" if diff is None else
|
||||
f"so the symmetric difference on key sets is\n"
|
||||
f" {' '.join(str(k.key) for k in diff)}"))
|
||||
return # don't look for more errors in this subtree
|
||||
# First we check special case types (list and tuple, though if they were
|
||||
# pytrees we could check strings and sets here, basically Sequences) so that
|
||||
# we can report length disagreement rather than integer keys:
|
||||
if isinstance(prefix_tree, (list, tuple)):
|
||||
if len(prefix_tree) != len(full_tree):
|
||||
ty = type(prefix_tree)
|
||||
yield lambda name: ValueError(
|
||||
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type "
|
||||
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
|
||||
f"has a subtree of the same type but of length {len(full_tree)}."
|
||||
.format(name=name))
|
||||
return # don't look for more errors in this subtree
|
||||
else:
|
||||
# Next we handle the general case of checking child keys.
|
||||
try:
|
||||
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
|
||||
except:
|
||||
diff = None
|
||||
if len(prefix_tree_children) != len(full_tree_children):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different numbers of pytree children at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with {len(prefix_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of the same "
|
||||
f"type but with {len(full_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
|
||||
.format(name=name)
|
||||
+ ("" if diff is None else
|
||||
f"so the symmetric difference on key sets is\n"
|
||||
f" {' '.join(str(k.key) for k in diff)}"))
|
||||
return # don't look for more errors in this subtree
|
||||
|
||||
# Or they may disagree if their roots have different pytree metadata:
|
||||
if prefix_tree_meta != full_tree_meta:
|
||||
|
@ -507,8 +507,25 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e2('in_axes')
|
||||
|
||||
def test_different_num_children(self):
|
||||
def test_different_num_children_tuple(self):
|
||||
e, = prefix_errors((1,), (2, 3))
|
||||
expected = ("pytree structure error: different lengths of tuple "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_different_num_children_list(self):
|
||||
e, = prefix_errors([1], [2, 3])
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
|
||||
def test_different_num_children_generic(self):
|
||||
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
@ -517,7 +534,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_different_num_children_nested(self):
|
||||
e, = prefix_errors([[1]], [[2, 3]])
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
r" in_axes\[0\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
@ -525,12 +542,12 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_different_num_children_multiple(self):
|
||||
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
r" in_axes\[0\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e1('in_axes')
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
r" in_axes\[1\]")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
|
Loading…
x
Reference in New Issue
Block a user