diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 29620cb50..a4cc36982 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -325,7 +325,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise + return args_to_raise, jaxprs.solve.effects def _custom_linear_solve_impl(*args, const_lengths, jaxprs): @@ -482,7 +482,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs): linear_solve_p = core.Primitive('custom_linear_solve') linear_solve_p.multiple_results = True linear_solve_p.def_impl(_custom_linear_solve_impl) -linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) +linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval) ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp xla.register_initial_style_primitive(linear_solve_p) mlir.register_lowering( diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 857dc34d4..8dfb338f3 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -482,6 +482,26 @@ class CustomLinearSolveTest(jtu.JaxTestCase): # doesn't crash jax.vmap(solve_aux)(b) + def test_custom_linear_solve_ordered_effects(self): + # See https://github.com/jax-ml/jax/issues/26087 + def mat_vec(v): + jax.debug.callback(lambda: print("mat_vec"), ordered=True) + return v + + def solve(b): + return lax.custom_linear_solve(mat_vec, b, lambda matvec, x: matvec(x)) + + b = self.rng().randn(24) + with jtu.capture_stdout() as output: + expected = solve(b) + jax.effects_barrier() + self.assertEqual(output(), "mat_vec\n") + with jtu.capture_stdout() as output: + computed = jax.jit(solve)(b) + jax.effects_barrier() + self.assertEqual(output(), "mat_vec\n") + self.assertAllClose(computed, expected) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())