mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
Add JVP rule for linear_call.
This commit is contained in:
parent
f7a2760822
commit
e1aa83ad67
@ -1447,6 +1447,19 @@ def _linear_call_impl(*args, callee, transpose_thunk, num_callee_consts,
|
||||
del transpose_thunk, num_callee_consts, num_res
|
||||
return core.eval_jaxpr(callee.jaxpr, (), *args)
|
||||
|
||||
def _linear_call_jvp_rule(primals, tangents, callee, transpose_thunk,
|
||||
num_callee_consts, num_res):
|
||||
consts_and_res, primals = split_list(primals, [num_callee_consts + num_res])
|
||||
const_tangents, tangents = split_list(tangents, [num_callee_consts + num_res])
|
||||
assert all(type(t) is Zero for t in const_tangents)
|
||||
primals_out = linear_call_p.bind(
|
||||
*consts_and_res, *primals, callee=callee, transpose_thunk=transpose_thunk,
|
||||
num_callee_consts=num_callee_consts, num_res=num_res)
|
||||
tangents_out = linear_call_p.bind(
|
||||
*consts_and_res, *tangents, callee=callee, transpose_thunk=transpose_thunk,
|
||||
num_callee_consts=num_callee_consts, num_res=num_res)
|
||||
return primals_out, tangents_out
|
||||
|
||||
def _linear_call_transpose_rule(cts, *args, callee, transpose_thunk,
|
||||
num_callee_consts, num_res):
|
||||
transpose, t_consts = transpose_thunk()
|
||||
@ -1478,6 +1491,7 @@ linear_call_p = core.Primitive('linear_call')
|
||||
linear_call_p.multiple_results = True
|
||||
linear_call_p.def_impl(_linear_call_impl)
|
||||
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
|
||||
ad.primitive_jvps[linear_call_p] = _linear_call_jvp_rule
|
||||
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
|
||||
xla.register_initial_style_primitive(linear_call_p)
|
||||
mlir.register_lowering(linear_call_p, mlir.lower_fun(
|
||||
|
@ -10073,6 +10073,19 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
return jax.custom_derivatives.linear_call(fn, tp, None, x)
|
||||
jax.jit(f)(0.1)
|
||||
|
||||
def test_linear_call_grad(self):
|
||||
def f(x, y):
|
||||
def fn(r, x): return x / r
|
||||
def tp(r, t): return t / r
|
||||
return x + jax.custom_derivatives.linear_call(fn, tp, y, x)
|
||||
|
||||
def f_ref(x, y):
|
||||
return x + x / y
|
||||
|
||||
x = jnp.array(6.)
|
||||
y = jnp.array(3.)
|
||||
self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y))
|
||||
|
||||
def test_basic(self):
|
||||
def f(x, y):
|
||||
@custom_transpose(jnp.ones(2))
|
||||
|
Loading…
x
Reference in New Issue
Block a user