mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add default implementation of process_custom_jvp_call and
process_custom_vjp_call to `jax.experimental.callback`
This commit is contained in:
parent
51276c8ad5
commit
e8901d51af
@ -156,3 +156,16 @@ class CallbackTrace(Trace):
|
||||
f = callback_subtrace(f, self.main)
|
||||
vals_out = call_primitive.bind(f, *vals_in, **params)
|
||||
return [CallbackTracer(self, val) for val in vals_out]
|
||||
|
||||
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
|
||||
# This implementation just drops the custom derivative rule.
|
||||
# TODO(sharadmv): don't drop the custom derivative rule
|
||||
del primitive, jvp # Unused.
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
|
||||
out_trees):
|
||||
# This implementation just drops the custom derivative rule.
|
||||
# TODO(sharadmv): don't drop the custom derivative rule
|
||||
del primitive, fwd, bwd, out_trees # Unused.
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
@ -15,6 +15,7 @@
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental.callback import find_by_value, rewrite, FoundValue
|
||||
import jax.numpy as jnp
|
||||
@ -89,5 +90,17 @@ class CallbackTest(jtu.JaxTestCase):
|
||||
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
|
||||
jnp.array([4.0, 6.0]))
|
||||
|
||||
def testRewriteWithCustomGradients(self):
|
||||
def f(x):
|
||||
return jax.nn.relu(x)
|
||||
|
||||
x = jnp.array([2.0, 4.0])
|
||||
self.assertAllClose(f(x), jnp.array([2.0, 4.0]))
|
||||
|
||||
self.assertAllClose(
|
||||
rewrite(f, {})(x),
|
||||
jnp.array([2.0, 4.0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user