1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

Add JVP rule for linear_call.

This commit is contained in:
Dan Foreman-Mackey 2025-04-08 15:35:25 -04:00
parent f7a2760822
commit e1aa83ad67
2 changed files with 27 additions and 0 deletions

@ -1447,6 +1447,19 @@ def _linear_call_impl(*args, callee, transpose_thunk, num_callee_consts,
del transpose_thunk, num_callee_consts, num_res
return core.eval_jaxpr(callee.jaxpr, (), *args)
def _linear_call_jvp_rule(primals, tangents, callee, transpose_thunk,
num_callee_consts, num_res):
consts_and_res, primals = split_list(primals, [num_callee_consts + num_res])
const_tangents, tangents = split_list(tangents, [num_callee_consts + num_res])
assert all(type(t) is Zero for t in const_tangents)
primals_out = linear_call_p.bind(
*consts_and_res, *primals, callee=callee, transpose_thunk=transpose_thunk,
num_callee_consts=num_callee_consts, num_res=num_res)
tangents_out = linear_call_p.bind(
*consts_and_res, *tangents, callee=callee, transpose_thunk=transpose_thunk,
num_callee_consts=num_callee_consts, num_res=num_res)
return primals_out, tangents_out
def _linear_call_transpose_rule(cts, *args, callee, transpose_thunk,
num_callee_consts, num_res):
transpose, t_consts = transpose_thunk()
@ -1478,6 +1491,7 @@ linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_jvps[linear_call_p] = _linear_call_jvp_rule
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
xla.register_initial_style_primitive(linear_call_p)
mlir.register_lowering(linear_call_p, mlir.lower_fun(

@ -10073,6 +10073,19 @@ class CustomTransposeTest(jtu.JaxTestCase):
return jax.custom_derivatives.linear_call(fn, tp, None, x)
jax.jit(f)(0.1)
def test_linear_call_grad(self):
def f(x, y):
def fn(r, x): return x / r
def tp(r, t): return t / r
return x + jax.custom_derivatives.linear_call(fn, tp, y, x)
def f_ref(x, y):
return x + x / y
x = jnp.array(6.)
y = jnp.array(3.)
self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y))
def test_basic(self):
def f(x, y):
@custom_transpose(jnp.ones(2))