mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix custom_jvp
check for tracers in arguments marked nondiff_argnums
PiperOrigin-RevId: 520379098
This commit is contained in:
parent
f282c251d4
commit
c8a7d5990d
@ -568,7 +568,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
|
||||
def _check_for_tracers(x):
|
||||
for leaf in tree_leaves(x):
|
||||
if isinstance(x, core.Tracer):
|
||||
if isinstance(leaf, core.Tracer):
|
||||
msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
|
||||
"function in a position indicated by nondiff_argnums as "
|
||||
"non-differentiable. Tracers cannot be passed as non-differentiable "
|
||||
|
Loading…
x
Reference in New Issue
Block a user