mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
parent
ec5b1c93d7
commit
b27acedf1f
@ -475,15 +475,27 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
# point, and because prefix_tree is not a leaf, each can be flattened once):
|
||||
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
|
||||
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
|
||||
prefix_tree_keys = _child_keys(prefix_tree)
|
||||
full_tree_keys = _child_keys(full_tree)
|
||||
try:
|
||||
diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
|
||||
except:
|
||||
diff = None
|
||||
if len(prefix_tree_children) != len(full_tree_children):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different numbers of pytree children at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with {len(prefix_tree_children)} children, "
|
||||
f"with {len(prefix_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of the same "
|
||||
f"type but with {len(full_tree_children)} children.".format(name=name))
|
||||
f"type but with {len(full_tree_children)} child keys\n"
|
||||
f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
|
||||
.format(name=name)
|
||||
+ ("" if diff is None else
|
||||
f"so the symmetric difference on key sets is\n"
|
||||
f" {' '.join(str(k.key) for k in diff)}"))
|
||||
return # don't look for more errors in this subtree
|
||||
|
||||
# Or they may disagree if their roots have different pytree metadata:
|
||||
@ -510,11 +522,10 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
|
||||
# If the root types and numbers of children agree, there must be an error
|
||||
# in a subtree, so recurse:
|
||||
keys = _child_keys(prefix_tree)
|
||||
keys_ = _child_keys(full_tree)
|
||||
assert keys == keys_, \
|
||||
f"equal pytree nodes gave differing keys: {keys} and {keys_}"
|
||||
for k, t1, t2 in zip(keys, prefix_tree_children, full_tree_children):
|
||||
assert prefix_tree_keys == full_tree_keys, \
|
||||
("equal pytree nodes gave differing prefix_tree_keys: "
|
||||
f"{prefix_tree_keys} and {full_tree_keys}")
|
||||
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
|
||||
yield from _prefix_error(key_path + k, t1, t2)
|
||||
|
||||
|
||||
|
@ -2603,9 +2603,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
" pjit out_axis_resources tree root\n"
|
||||
"At that key path, the prefix pytree pjit out_axis_resources has a "
|
||||
"subtree of type\n"
|
||||
" <class 'list'>\n"
|
||||
"with 2 children, but at the same key path the full pytree has a "
|
||||
"subtree of the same type but with 3 children.")
|
||||
" <class 'list'>\n")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message
|
||||
|
||||
|
@ -505,6 +505,13 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e2('in_axes')
|
||||
|
||||
def test_different_num_children_print_key_diff(self):
|
||||
e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3})
|
||||
expected = ("so the symmetric difference on key sets is\n"
|
||||
" b")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
def test_different_metadata(self):
|
||||
e, = prefix_errors({1: 2}, {3: 4})
|
||||
expected = ("pytree structure error: different pytree metadata "
|
||||
|
Loading…
x
Reference in New Issue
Block a user