mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140 Previously, the following code: ``` def f(i, x): return lax.switch(i, [lambda x: dict(a=x), lambda x: dict(a=(x, x))], x) f(0, 42) ``` resulted in the error message: ``` TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}). ``` With this change the error message is more specific where the difference is in the pytree structure: ``` TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences: * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ ```
This commit is contained in:
parent
640cb009f1
commit
c2adfbf1c2
@ -30,7 +30,8 @@ from jax._src import util
|
||||
from jax._src.util import weakref_lru_cache, safe_map, partition_list
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_map, tree_unflatten
|
||||
from jax.tree_util import tree_map, tree_unflatten, keystr
|
||||
from jax._src.tree_util import equality_errors_pytreedef
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
@ -188,20 +189,28 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
|
||||
jaxpr = jaxpr.replace(effects=effects)
|
||||
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
|
||||
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
|
||||
def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2):
|
||||
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
|
||||
|
||||
Corresponding `tree` and `avals` must match in the sense that the number of
|
||||
leaves in `tree` must be equal to the length of `avals`. `what` will be
|
||||
prepended to details of the mismatch in TypeError.
|
||||
leaves in `tree` must be equal to the length of `avals`. `what1` and
|
||||
`what2` describe what the `tree1` and `tree2` represent.
|
||||
"""
|
||||
if tree1 != tree2:
|
||||
raise TypeError(
|
||||
f"{what} must have same type structure, got {tree1} and {tree2}.")
|
||||
errs = list(equality_errors_pytreedef(tree1, tree2))
|
||||
msg = []
|
||||
msg.append(
|
||||
f"{what1} must have same type structure as {what2}, but there are differences: ")
|
||||
for path, thing1, thing2, explanation in errs:
|
||||
msg.append(
|
||||
f" * at output{keystr(tuple(path))}, {what1} has {thing1} and "
|
||||
f"{what2} has {thing2}, so {explanation}")
|
||||
raise TypeError('\n'.join(msg))
|
||||
|
||||
if not all(map(core.typematch, avals1, avals2)):
|
||||
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
|
||||
tree_unflatten(tree2, avals2))
|
||||
raise TypeError(f"{what} must have identical types, got\n{diff}.")
|
||||
raise TypeError(f"{what1} and {what2} must have identical types, got\n{diff}.")
|
||||
|
||||
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
|
||||
if has_aux:
|
||||
|
@ -147,8 +147,9 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
||||
_check_tree_and_avals("branch 0 output",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
f"branch {i + 1} output",
|
||||
out_tree, jaxpr.out_avals)
|
||||
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
|
||||
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
|
||||
@ -250,8 +251,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
true_jaxpr.out_avals + false_jaxpr.out_avals):
|
||||
raise ValueError("Cannot return `Ref`s from `cond`.")
|
||||
|
||||
_check_tree_and_avals("true_fun and false_fun output",
|
||||
_check_tree_and_avals("true_fun output",
|
||||
out_tree, true_jaxpr.out_avals,
|
||||
"false_fun output",
|
||||
false_out_tree, false_jaxpr.out_avals)
|
||||
# prune passhtrough outputs
|
||||
true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
|
||||
|
@ -843,8 +843,9 @@ sparse_rules_bcoo[lax.scan_p] = _scan_sparse
|
||||
def _cond_sparse(spenv, pred, *operands, branches, **params):
|
||||
sp_branches, treedefs = zip(*(_sparsify_jaxpr(spenv, jaxpr, *operands)
|
||||
for jaxpr in branches))
|
||||
_check_tree_and_avals("sparsified true_fun and false_fun output",
|
||||
_check_tree_and_avals("sparsified true_fun output",
|
||||
treedefs[0], sp_branches[0].out_avals,
|
||||
"sparsified false_fun output",
|
||||
treedefs[1], sp_branches[1].out_avals)
|
||||
args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands)))
|
||||
out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params)
|
||||
|
@ -995,12 +995,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
|
||||
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("true_fun and false_fun output must have same type structure, "
|
||||
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
|
||||
lax.cond(True, lambda top: 2., lambda fop: (3., 3.), 1.)
|
||||
re.compile("true_fun output must have same type structure "
|
||||
"as false_fun output, but there are differences:.*"
|
||||
r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
|
||||
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"true_fun and false_fun output must have identical types, got\n"
|
||||
"true_fun output and false_fun output must have identical types, got\n"
|
||||
r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
|
||||
r"ShapedArray\(float32\[\].*\)."):
|
||||
lax.cond(True,
|
||||
@ -1023,16 +1024,17 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
re.escape("Empty branch sequence")):
|
||||
lax.switch(0, [], 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("branch 0 and 1 outputs must have same type structure, "
|
||||
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
|
||||
lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
|
||||
re.compile("branch 0 output must have same type structure "
|
||||
"as branch 1 output, but there are differences:.*"
|
||||
r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
|
||||
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"branch 0 and 1 outputs must have identical types, got\n"
|
||||
r"DIFFERENT ShapedArray\(float32\[1\]\) "
|
||||
r"vs. ShapedArray\(float32\[\].*\)."):
|
||||
lax.switch(1, [lambda _: jnp.array([1.], jnp.float32),
|
||||
lambda _: jnp.float32(1.)],
|
||||
"branch 0 output and branch 1 output must have identical types, got\n"
|
||||
r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
|
||||
r"vs. ShapedArray\(float32\[\].*\)'}."):
|
||||
lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
|
||||
lambda _: dict(a=jnp.float32(1.))],
|
||||
1.)
|
||||
|
||||
def testCondOneBranchConstant(self):
|
||||
|
@ -560,7 +560,9 @@ class SparsifyTest(jtu.JaxTestCase):
|
||||
func(x, y) # No error
|
||||
func(x_bcoo, y_bcoo) # No error
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"sparsified true_fun output must have same type structure as sparsified false_fun output.*"):
|
||||
func(x_bcoo, y)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
Loading…
x
Reference in New Issue
Block a user