Add comments for residuals from f_bwd. (#4244)

This commit is contained in:
Qiao Zhang 2020-09-10 03:58:28 -07:00 committed by GitHub
parent 0962ceb057
commit 26a53ae554
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -193,10 +193,11 @@
" return jnp.sin(x) * y\n",
"\n",
"def f_fwd(x, y):\n",
"# Returns primal output and residuals to be used in backward pass by f_bwd.\n",
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
"\n",
"def f_bwd(res, g):\n",
" cos_x, sin_x, y = res\n",
" cos_x, sin_x, y = res # Gets residuals computed in f_fwd\n",
" return (cos_x * g * y, sin_x * g)\n",
"\n",
"f.defvjp(f_fwd, f_bwd)"
@ -2820,4 +2821,4 @@
]
}
]
}
}