diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index 294943d2c..4ffda4bb1 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -30,7 +30,8 @@ from jax._src import util
 from jax._src.util import weakref_lru_cache, safe_map, partition_list
 from jax.api_util import flatten_fun_nokwargs
 from jax._src.interpreters import partial_eval as pe
-from jax.tree_util import tree_map, tree_unflatten
+from jax.tree_util import tree_map, tree_unflatten, keystr
+from jax._src.tree_util import equality_errors_pytreedef
 
 map, unsafe_map = safe_map, map
 
@@ -188,20 +189,28 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
   jaxpr = jaxpr.replace(effects=effects)
   return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
 
-def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
+def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2):
   """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
 
   Corresponding `tree` and `avals` must match in the sense that the number of
-  leaves in `tree` must be equal to the length of `avals`. `what` will be
-  prepended to details of the mismatch in TypeError.
+  leaves in `tree` must be equal to the length of `avals`. `what1` and
+  `what2` describe what the `tree1` and `tree2` represent.
   """
   if tree1 != tree2:
-    raise TypeError(
-        f"{what} must have same type structure, got {tree1} and {tree2}.")
+    errs = list(equality_errors_pytreedef(tree1, tree2))
+    msg = []
+    msg.append(
+        f"{what1} must have same type structure as {what2}, but there are differences: ")
+    for path, thing1, thing2, explanation in errs:
+      msg.append(
+          f"    * at output{keystr(tuple(path))}, {what1} has {thing1} and "
+          f"{what2} has {thing2}, so {explanation}")
+    raise TypeError('\n'.join(msg))
+
   if not all(map(core.typematch, avals1, avals2)):
     diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
                     tree_unflatten(tree2, avals2))
-    raise TypeError(f"{what} must have identical types, got\n{diff}.")
+    raise TypeError(f"{what1} and {what2} must have identical types, got\n{diff}.")
 
 def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
   if has_aux:
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index db0a1f4dc..971e00752 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -147,8 +147,9 @@ def switch(index, branches: Sequence[Callable], *operands,
   if config.mutable_array_checks.value:
     _check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
   for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
-    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
+    _check_tree_and_avals("branch 0 output",
                           out_trees[0], jaxprs[0].out_avals,
+                          f"branch {i + 1} output",
                           out_tree, jaxpr.out_avals)
   joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
   disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
@@ -250,8 +251,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          true_jaxpr.out_avals + false_jaxpr.out_avals):
     raise ValueError("Cannot return `Ref`s from `cond`.")
 
-  _check_tree_and_avals("true_fun and false_fun output",
+  _check_tree_and_avals("true_fun output",
                         out_tree, true_jaxpr.out_avals,
+                        "false_fun output",
                         false_out_tree, false_jaxpr.out_avals)
   # prune passhtrough outputs
   true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index 9c6e4f816..0f255747e 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -843,8 +843,9 @@ sparse_rules_bcoo[lax.scan_p] = _scan_sparse
 def _cond_sparse(spenv, pred, *operands, branches, **params):
   sp_branches, treedefs = zip(*(_sparsify_jaxpr(spenv, jaxpr, *operands)
                                 for jaxpr in branches))
-  _check_tree_and_avals("sparsified true_fun and false_fun output",
+  _check_tree_and_avals("sparsified true_fun output",
                         treedefs[0], sp_branches[0].out_avals,
+                        "sparsified false_fun output",
                         treedefs[1], sp_branches[1].out_avals)
   args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands)))
   out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params)
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index f323b035d..68c5d45c9 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -995,12 +995,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
         re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
       lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
     with self.assertRaisesRegex(TypeError,
-        re.escape("true_fun and false_fun output must have same type structure, "
-                  f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
-      lax.cond(True, lambda top: 2., lambda fop: (3., 3.), 1.)
+        re.compile("true_fun output must have same type structure "
+                   "as false_fun output, but there are differences:.*"
+                   r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
+      lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
     with self.assertRaisesRegex(
         TypeError,
-        "true_fun and false_fun output must have identical types, got\n"
+        "true_fun output and false_fun output must have identical types, got\n"
         r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
         r"ShapedArray\(float32\[\].*\)."):
       lax.cond(True,
@@ -1023,16 +1024,17 @@ class LaxControlFlowTest(jtu.JaxTestCase):
         re.escape("Empty branch sequence")):
       lax.switch(0, [], 1.)
     with self.assertRaisesRegex(TypeError,
-        re.escape("branch 0 and 1 outputs must have same type structure, "
-                  f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
-      lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
+        re.compile("branch 0 output must have same type structure "
+                   "as branch 1 output, but there are differences:.*"
+                   r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
+      lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
     with self.assertRaisesRegex(
         TypeError,
-        "branch 0 and 1 outputs must have identical types, got\n"
-        r"DIFFERENT ShapedArray\(float32\[1\]\) "
-        r"vs. ShapedArray\(float32\[\].*\)."):
-      lax.switch(1, [lambda _: jnp.array([1.], jnp.float32),
-                     lambda _: jnp.float32(1.)],
+        "branch 0 output and branch 1 output must have identical types, got\n"
+        r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
+        r"vs. ShapedArray\(float32\[\].*\)'}."):
+      lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
+                     lambda _: dict(a=jnp.float32(1.))],
                  1.)
 
   def testCondOneBranchConstant(self):
diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py
index 46c2f5aaf..abef6ca56 100644
--- a/tests/sparsify_test.py
+++ b/tests/sparsify_test.py
@@ -560,7 +560,9 @@ class SparsifyTest(jtu.JaxTestCase):
     func(x, y)  # No error
     func(x_bcoo, y_bcoo)  # No error
 
-    with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
+    with self.assertRaisesRegex(
+        TypeError,
+        "sparsified true_fun output must have same type structure as sparsified false_fun output.*"):
       func(x_bcoo, y)
 
   @parameterized.named_parameters(