mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
fix checkify custom_jvp rule to handle symbolic zeros
likely broken in #15426, or maybe not quite right before either Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
2694bf6207
commit
391e95a683
@ -35,6 +35,7 @@ from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util as jtu
|
||||
from jax._src.ad_util import SymbolicZero
|
||||
from jax._src.api_util import flatten_fun
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
@ -911,7 +912,7 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
|
||||
call_jaxpr.consts, enabled_errors, err_tree))
|
||||
partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
|
||||
partial_checkify)
|
||||
jvp = custom_derivatives.lift_jvp(num_consts, jvp_jaxpr_thunk)
|
||||
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
|
||||
jvp, jvp_out_tree = flatten_fun_output(jvp)
|
||||
all_outs = custom_derivatives.custom_jvp_call_p.bind(
|
||||
partial_checkify, jvp, *err_vals, *in_vals, **params)
|
||||
@ -926,6 +927,32 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
|
||||
return out_err, out_vals
|
||||
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
|
||||
|
||||
# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
|
||||
# outputs that checkify adds (just forwarding the error data's primal and
|
||||
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
|
||||
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
|
||||
# Adding another layer of lu.transformation was tricky, though maybe doable.
|
||||
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
|
||||
@lu.wrap_init
|
||||
def jvp(*xs):
|
||||
n, ragged = divmod(len(xs), 2)
|
||||
assert not ragged
|
||||
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
|
||||
zeros = [type(t) is SymbolicZero for t in tangents]
|
||||
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
|
||||
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
|
||||
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
|
||||
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
|
||||
nz_out_tangents_ = iter(nz_out_tangents)
|
||||
out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
|
||||
if z else next(nz_out_tangents_)
|
||||
for p, z in zip(out_primals, out_zeros)]
|
||||
assert next(nz_out_tangents_, None) is None
|
||||
primal_errs = xs[num_consts:num_consts+num_errs]
|
||||
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
|
||||
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
|
||||
return jvp
|
||||
|
||||
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
|
||||
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
|
||||
symbolic_zeros):
|
||||
|
@ -817,6 +817,21 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
_ = f(3.)
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
def test_goodfellow_custom_jvp(self):
|
||||
def h(fext):
|
||||
checkify.check(True, "")
|
||||
return jax.nn.relu(fext)
|
||||
|
||||
h = checkify.checkify(h)
|
||||
|
||||
def h_out(fext):
|
||||
_, out = h(fext)
|
||||
return out
|
||||
|
||||
h_grad = jax.grad(h_out)
|
||||
h_grad(0.) # doesn't crash
|
||||
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user