diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 1a63fd1c9..784e1a6bb 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -492,33 +492,49 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, f" {type(full_tree)}.".format(name=name)) return # don't look for more errors in this subtree - # Or they may disagree if their roots have different numbers of children (note - # that because both prefix_tree and full_tree have the same type at this - # point, and because prefix_tree is not a leaf, each can be flattened once): + # Or they may disagree if their roots have different numbers or keys of + # children. Because both prefix_tree and full_tree have the same type at this + # point, and because prefix_tree is not a leaf, each can be flattened once: prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree) full_tree_children, full_tree_meta = flatten_one_level(full_tree) prefix_tree_keys = _child_keys(prefix_tree) full_tree_keys = _child_keys(full_tree) - try: - diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys)) - except: - diff = None - if len(prefix_tree_children) != len(full_tree_children): - yield lambda name: ValueError( - "pytree structure error: different numbers of pytree children at key path\n" - f" {{name}}{key_path.pprint()}\n" - f"At that key path, the prefix pytree {{name}} has a subtree of type\n" - f" {type(prefix_tree)}\n" - f"with {len(prefix_tree_children)} child keys\n" - f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n" - f"but at the same key path the full pytree has a subtree of the same " - f"type but with {len(full_tree_children)} child keys\n" - f" {' '.join(str(k.key) for k in full_tree_keys)}\n" - .format(name=name) - + ("" if diff is None else - f"so the symmetric difference on key sets is\n" - f" {' '.join(str(k.key) for k in diff)}")) - return # don't look for more errors in this subtree + # First we check special case types (list and tuple, though if they were + # pytrees we could check strings and sets here, basically Sequences) so that + # we can report length disagreement rather than integer keys: + if isinstance(prefix_tree, (list, tuple)): + if len(prefix_tree) != len(full_tree): + ty = type(prefix_tree) + yield lambda name: ValueError( + f"pytree structure error: different lengths of {ty.__name__} at key path\n" + f" {{name}}{key_path.pprint()}\n" + f"At that key path, the prefix pytree {{name}} has a subtree of type " + f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree " + f"has a subtree of the same type but of length {len(full_tree)}." + .format(name=name)) + return # don't look for more errors in this subtree + else: + # Next we handle the general case of checking child keys. + try: + diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys)) + except: + diff = None + if len(prefix_tree_children) != len(full_tree_children): + yield lambda name: ValueError( + "pytree structure error: different numbers of pytree children at key path\n" + f" {{name}}{key_path.pprint()}\n" + f"At that key path, the prefix pytree {{name}} has a subtree of type\n" + f" {type(prefix_tree)}\n" + f"with {len(prefix_tree_children)} child keys\n" + f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n" + f"but at the same key path the full pytree has a subtree of the same " + f"type but with {len(full_tree_children)} child keys\n" + f" {' '.join(str(k.key) for k in full_tree_keys)}\n" + .format(name=name) + + ("" if diff is None else + f"so the symmetric difference on key sets is\n" + f" {' '.join(str(k.key) for k in diff)}")) + return # don't look for more errors in this subtree # Or they may disagree if their roots have different pytree metadata: if prefix_tree_meta != full_tree_meta: diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 0c51d4af6..890a834fc 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -507,8 +507,25 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, expected): raise e2('in_axes') - def test_different_num_children(self): + def test_different_num_children_tuple(self): e, = prefix_errors((1,), (2, 3)) + expected = ("pytree structure error: different lengths of tuple " + "at key path\n" + " in_axes tree root") + with self.assertRaisesRegex(ValueError, expected): + raise e('in_axes') + + def test_different_num_children_list(self): + e, = prefix_errors([1], [2, 3]) + expected = ("pytree structure error: different lengths of list " + "at key path\n" + " in_axes tree root") + with self.assertRaisesRegex(ValueError, expected): + raise e('in_axes') + + + def test_different_num_children_generic(self): + e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3}) expected = ("pytree structure error: different numbers of pytree children " "at key path\n" " in_axes tree root") @@ -517,7 +534,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): def test_different_num_children_nested(self): e, = prefix_errors([[1]], [[2, 3]]) - expected = ("pytree structure error: different numbers of pytree children " + expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): @@ -525,12 +542,12 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): def test_different_num_children_multiple(self): e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]]) - expected = ("pytree structure error: different numbers of pytree children " + expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[0\]") with self.assertRaisesRegex(ValueError, expected): raise e1('in_axes') - expected = ("pytree structure error: different numbers of pytree children " + expected = ("pytree structure error: different lengths of list " "at key path\n" r" in_axes\[1\]") with self.assertRaisesRegex(ValueError, expected):