Add default implementation of process_custom_jvp_call and

process_custom_vjp_call to `jax.experimental.callback`
This commit is contained in:
Sharad Vikram 2020-10-16 14:53:23 -07:00
parent 51276c8ad5
commit e8901d51af
2 changed files with 26 additions and 0 deletions

View File

@ -156,3 +156,16 @@ class CallbackTrace(Trace):
f = callback_subtrace(f, self.main)
vals_out = call_primitive.bind(f, *vals_in, **params)
return [CallbackTracer(self, val) for val in vals_out]
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
# This implementation just drops the custom derivative rule.
# TODO(sharadmv): don't drop the custom derivative rule
del primitive, jvp # Unused.
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees):
# This implementation just drops the custom derivative rule.
# TODO(sharadmv): don't drop the custom derivative rule
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)

View File

@ -15,6 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax.experimental.callback import find_by_value, rewrite, FoundValue
import jax.numpy as jnp
@ -89,5 +90,17 @@ class CallbackTest(jtu.JaxTestCase):
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
jnp.array([4.0, 6.0]))
def testRewriteWithCustomGradients(self):
def f(x):
return jax.nn.relu(x)
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([2.0, 4.0]))
self.assertAllClose(
rewrite(f, {})(x),
jnp.array([2.0, 4.0]))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())