[better_errors] Improve error message for lax.switch branches output structure mismatch

Fixes: #25140

Previously, the following code:
```
def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change the error message is more specific where the
difference is in the pytree structure:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
This commit is contained in:
George Necula 2025-01-09 14:15:39 +02:00
parent 640cb009f1
commit c2adfbf1c2
5 changed files with 39 additions and 23 deletions

View File

@ -30,7 +30,8 @@ from jax._src import util
from jax._src.util import weakref_lru_cache, safe_map, partition_list
from jax.api_util import flatten_fun_nokwargs
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten
from jax.tree_util import tree_map, tree_unflatten, keystr
from jax._src.tree_util import equality_errors_pytreedef
map, unsafe_map = safe_map, map
@ -188,20 +189,28 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
jaxpr = jaxpr.replace(effects=effects)
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2):
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
Corresponding `tree` and `avals` must match in the sense that the number of
leaves in `tree` must be equal to the length of `avals`. `what` will be
prepended to details of the mismatch in TypeError.
leaves in `tree` must be equal to the length of `avals`. `what1` and
`what2` describe what the `tree1` and `tree2` represent.
"""
if tree1 != tree2:
raise TypeError(
f"{what} must have same type structure, got {tree1} and {tree2}.")
errs = list(equality_errors_pytreedef(tree1, tree2))
msg = []
msg.append(
f"{what1} must have same type structure as {what2}, but there are differences: ")
for path, thing1, thing2, explanation in errs:
msg.append(
f" * at output{keystr(tuple(path))}, {what1} has {thing1} and "
f"{what2} has {thing2}, so {explanation}")
raise TypeError('\n'.join(msg))
if not all(map(core.typematch, avals1, avals2)):
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2))
raise TypeError(f"{what} must have identical types, got\n{diff}.")
raise TypeError(f"{what1} and {what2} must have identical types, got\n{diff}.")
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
if has_aux:

View File

@ -147,8 +147,9 @@ def switch(index, branches: Sequence[Callable], *operands,
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
_check_tree_and_avals("branch 0 output",
out_trees[0], jaxprs[0].out_avals,
f"branch {i + 1} output",
out_tree, jaxpr.out_avals)
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
@ -250,8 +251,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
true_jaxpr.out_avals + false_jaxpr.out_avals):
raise ValueError("Cannot return `Ref`s from `cond`.")
_check_tree_and_avals("true_fun and false_fun output",
_check_tree_and_avals("true_fun output",
out_tree, true_jaxpr.out_avals,
"false_fun output",
false_out_tree, false_jaxpr.out_avals)
# prune passhtrough outputs
true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)

View File

@ -843,8 +843,9 @@ sparse_rules_bcoo[lax.scan_p] = _scan_sparse
def _cond_sparse(spenv, pred, *operands, branches, **params):
sp_branches, treedefs = zip(*(_sparsify_jaxpr(spenv, jaxpr, *operands)
for jaxpr in branches))
_check_tree_and_avals("sparsified true_fun and false_fun output",
_check_tree_and_avals("sparsified true_fun output",
treedefs[0], sp_branches[0].out_avals,
"sparsified false_fun output",
treedefs[1], sp_branches[1].out_avals)
args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands)))
out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params)

View File

@ -995,12 +995,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
with self.assertRaisesRegex(TypeError,
re.escape("true_fun and false_fun output must have same type structure, "
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
lax.cond(True, lambda top: 2., lambda fop: (3., 3.), 1.)
re.compile("true_fun output must have same type structure "
"as false_fun output, but there are differences:.*"
r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
with self.assertRaisesRegex(
TypeError,
"true_fun and false_fun output must have identical types, got\n"
"true_fun output and false_fun output must have identical types, got\n"
r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
r"ShapedArray\(float32\[\].*\)."):
lax.cond(True,
@ -1023,16 +1024,17 @@ class LaxControlFlowTest(jtu.JaxTestCase):
re.escape("Empty branch sequence")):
lax.switch(0, [], 1.)
with self.assertRaisesRegex(TypeError,
re.escape("branch 0 and 1 outputs must have same type structure, "
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
re.compile("branch 0 output must have same type structure "
"as branch 1 output, but there are differences:.*"
r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
with self.assertRaisesRegex(
TypeError,
"branch 0 and 1 outputs must have identical types, got\n"
r"DIFFERENT ShapedArray\(float32\[1\]\) "
r"vs. ShapedArray\(float32\[\].*\)."):
lax.switch(1, [lambda _: jnp.array([1.], jnp.float32),
lambda _: jnp.float32(1.)],
"branch 0 output and branch 1 output must have identical types, got\n"
r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
r"vs. ShapedArray\(float32\[\].*\)'}."):
lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
lambda _: dict(a=jnp.float32(1.))],
1.)
def testCondOneBranchConstant(self):

View File

@ -560,7 +560,9 @@ class SparsifyTest(jtu.JaxTestCase):
func(x, y) # No error
func(x_bcoo, y_bcoo) # No error
with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
with self.assertRaisesRegex(
TypeError,
"sparsified true_fun output must have same type structure as sparsified false_fun output.*"):
func(x_bcoo, y)
@parameterized.named_parameters(