mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add comments for residuals from f_bwd. (#4244)
This commit is contained in:
parent
0962ceb057
commit
26a53ae554
@ -193,10 +193,11 @@
|
||||
" return jnp.sin(x) * y\n",
|
||||
"\n",
|
||||
"def f_fwd(x, y):\n",
|
||||
"# Returns primal output and residuals to be used in backward pass by f_bwd.\n",
|
||||
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
|
||||
"\n",
|
||||
"def f_bwd(res, g):\n",
|
||||
" cos_x, sin_x, y = res\n",
|
||||
" cos_x, sin_x, y = res # Gets residuals computed in f_fwd\n",
|
||||
" return (cos_x * g * y, sin_x * g)\n",
|
||||
"\n",
|
||||
"f.defvjp(f_fwd, f_bwd)"
|
||||
@ -2820,4 +2821,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user