Handle effects in lax.custom_linear_solve.

This commit is contained in:
Dan Foreman-Mackey 2025-02-03 11:14:48 -05:00
parent 7164c6ba3e
commit d42e3650d0
2 changed files with 22 additions and 2 deletions

View File

@ -325,7 +325,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
if num_aux > 0:
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
return args_to_raise
return args_to_raise, jaxprs.solve.effects
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
@ -482,7 +482,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(

View File

@ -482,6 +482,26 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
# doesn't crash
jax.vmap(solve_aux)(b)
def test_custom_linear_solve_ordered_effects(self):
# See https://github.com/jax-ml/jax/issues/26087
def mat_vec(v):
jax.debug.callback(lambda: print("mat_vec"), ordered=True)
return v
def solve(b):
return lax.custom_linear_solve(mat_vec, b, lambda matvec, x: matvec(x))
b = self.rng().randn(24)
with jtu.capture_stdout() as output:
expected = solve(b)
jax.effects_barrier()
self.assertEqual(output(), "mat_vec\n")
with jtu.capture_stdout() as output:
computed = jax.jit(solve)(b)
jax.effects_barrier()
self.assertEqual(output(), "mat_vec\n")
self.assertAllClose(computed, expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())