fix checkify custom_jvp rule to handle symbolic zeros

likely broken in #15426, or maybe not quite right before either

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2023-05-09 14:11:36 -07:00
parent 2694bf6207
commit 391e95a683
2 changed files with 43 additions and 1 deletions

View File

@ -35,6 +35,7 @@ from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.config import config
from jax._src.interpreters import ad
@ -911,7 +912,7 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
call_jaxpr.consts, enabled_errors, err_tree))
partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
jvp = custom_derivatives.lift_jvp(num_consts, jvp_jaxpr_thunk)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
jvp, jvp_out_tree = flatten_fun_output(jvp)
all_outs = custom_derivatives.custom_jvp_call_p.bind(
partial_checkify, jvp, *err_vals, *in_vals, **params)
@ -926,6 +927,32 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
@lu.wrap_init
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
nz_out_tangents_ = iter(nz_out_tangents)
out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
if z else next(nz_out_tangents_)
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
primal_errs = xs[num_consts:num_consts+num_errs]
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
return jvp
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
symbolic_zeros):

View File

@ -817,6 +817,21 @@ class CheckifyTransformTests(jtu.JaxTestCase):
_ = f(3.)
self.assertEqual(count[0], 0)
def test_goodfellow_custom_jvp(self):
def h(fext):
checkify.check(True, "")
return jax.nn.relu(fext)
h = checkify.checkify(h)
def h_out(fext):
_, out = h(fext)
return out
h_grad = jax.grad(h_out)
h_grad(0.) # doesn't crash
@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):