fix custom_jvp check for tracers in arguments marked nondiff_argnums

PiperOrigin-RevId: 520379098
This commit is contained in:
Roy Frostig 2023-03-29 10:41:34 -07:00 committed by jax authors
parent f282c251d4
commit c8a7d5990d

View File

@ -568,7 +568,7 @@ class custom_vjp(Generic[ReturnValue]):
def _check_for_tracers(x):
for leaf in tree_leaves(x):
if isinstance(x, core.Tracer):
if isinstance(leaf, core.Tracer):
msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
"function in a position indicated by nondiff_argnums as "
"non-differentiable. Tracers cannot be passed as non-differentiable "