diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 294943d2c..4ffda4bb1 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -30,7 +30,8 @@ from jax._src import util from jax._src.util import weakref_lru_cache, safe_map, partition_list from jax.api_util import flatten_fun_nokwargs from jax._src.interpreters import partial_eval as pe -from jax.tree_util import tree_map, tree_unflatten +from jax.tree_util import tree_map, tree_unflatten, keystr +from jax._src.tree_util import equality_errors_pytreedef map, unsafe_map = safe_map, map @@ -188,20 +189,28 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, jaxpr = jaxpr.replace(effects=effects) return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) -def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): +def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2): """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). Corresponding `tree` and `avals` must match in the sense that the number of - leaves in `tree` must be equal to the length of `avals`. `what` will be - prepended to details of the mismatch in TypeError. + leaves in `tree` must be equal to the length of `avals`. `what1` and + `what2` describe what the `tree1` and `tree2` represent. """ if tree1 != tree2: - raise TypeError( - f"{what} must have same type structure, got {tree1} and {tree2}.") + errs = list(equality_errors_pytreedef(tree1, tree2)) + msg = [] + msg.append( + f"{what1} must have same type structure as {what2}, but there are differences: ") + for path, thing1, thing2, explanation in errs: + msg.append( + f" * at output{keystr(tuple(path))}, {what1} has {thing1} and " + f"{what2} has {thing2}, so {explanation}") + raise TypeError('\n'.join(msg)) + if not all(map(core.typematch, avals1, avals2)): diff = tree_map(_show_diff, tree_unflatten(tree1, avals1), tree_unflatten(tree2, avals2)) - raise TypeError(f"{what} must have identical types, got\n{diff}.") + raise TypeError(f"{what1} and {what2} must have identical types, got\n{diff}.") def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False): if has_aux: diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index db0a1f4dc..971e00752 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -147,8 +147,9 @@ def switch(index, branches: Sequence[Callable], *operands, if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): - _check_tree_and_avals(f"branch 0 and {i + 1} outputs", + _check_tree_and_avals("branch 0 output", out_trees[0], jaxprs[0].out_avals, + f"branch {i + 1} output", out_tree, jaxpr.out_avals) joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -250,8 +251,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") - _check_tree_and_avals("true_fun and false_fun output", + _check_tree_and_avals("true_fun output", out_tree, true_jaxpr.out_avals, + "false_fun output", false_out_tree, false_jaxpr.out_avals) # prune passhtrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 9c6e4f816..0f255747e 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -843,8 +843,9 @@ sparse_rules_bcoo[lax.scan_p] = _scan_sparse def _cond_sparse(spenv, pred, *operands, branches, **params): sp_branches, treedefs = zip(*(_sparsify_jaxpr(spenv, jaxpr, *operands) for jaxpr in branches)) - _check_tree_and_avals("sparsified true_fun and false_fun output", + _check_tree_and_avals("sparsified true_fun output", treedefs[0], sp_branches[0].out_avals, + "sparsified false_fun output", treedefs[1], sp_branches[1].out_avals) args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands))) out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index f323b035d..68c5d45c9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -995,12 +995,13 @@ class LaxControlFlowTest(jtu.JaxTestCase): re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")): lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.) with self.assertRaisesRegex(TypeError, - re.escape("true_fun and false_fun output must have same type structure, " - f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")): - lax.cond(True, lambda top: 2., lambda fop: (3., 3.), 1.) + re.compile("true_fun output must have same type structure " + "as false_fun output, but there are differences:.*" + r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)): + lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.) with self.assertRaisesRegex( TypeError, - "true_fun and false_fun output must have identical types, got\n" + "true_fun output and false_fun output must have identical types, got\n" r"DIFFERENT ShapedArray\(float32\[1\]\) vs. " r"ShapedArray\(float32\[\].*\)."): lax.cond(True, @@ -1023,16 +1024,17 @@ class LaxControlFlowTest(jtu.JaxTestCase): re.escape("Empty branch sequence")): lax.switch(0, [], 1.) with self.assertRaisesRegex(TypeError, - re.escape("branch 0 and 1 outputs must have same type structure, " - f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")): - lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.) + re.compile("branch 0 output must have same type structure " + "as branch 1 output, but there are differences:.*" + r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)): + lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.) with self.assertRaisesRegex( TypeError, - "branch 0 and 1 outputs must have identical types, got\n" - r"DIFFERENT ShapedArray\(float32\[1\]\) " - r"vs. ShapedArray\(float32\[\].*\)."): - lax.switch(1, [lambda _: jnp.array([1.], jnp.float32), - lambda _: jnp.float32(1.)], + "branch 0 output and branch 1 output must have identical types, got\n" + r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) " + r"vs. ShapedArray\(float32\[\].*\)'}."): + lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)), + lambda _: dict(a=jnp.float32(1.))], 1.) def testCondOneBranchConstant(self): diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 46c2f5aaf..abef6ca56 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -560,7 +560,9 @@ class SparsifyTest(jtu.JaxTestCase): func(x, y) # No error func(x_bcoo, y_bcoo) # No error - with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"): + with self.assertRaisesRegex( + TypeError, + "sparsified true_fun output must have same type structure as sparsified false_fun output.*"): func(x_bcoo, y) @parameterized.named_parameters(