mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Handle effects in lax.custom_linear_solve.
This commit is contained in:
parent
7164c6ba3e
commit
d42e3650d0
@ -325,7 +325,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
|
||||
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
|
||||
if num_aux > 0:
|
||||
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
|
||||
return args_to_raise
|
||||
return args_to_raise, jaxprs.solve.effects
|
||||
|
||||
|
||||
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
|
||||
@ -482,7 +482,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
|
||||
linear_solve_p = core.Primitive('custom_linear_solve')
|
||||
linear_solve_p.multiple_results = True
|
||||
linear_solve_p.def_impl(_custom_linear_solve_impl)
|
||||
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
|
||||
linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval)
|
||||
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
|
||||
xla.register_initial_style_primitive(linear_solve_p)
|
||||
mlir.register_lowering(
|
||||
|
@ -482,6 +482,26 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
|
||||
# doesn't crash
|
||||
jax.vmap(solve_aux)(b)
|
||||
|
||||
def test_custom_linear_solve_ordered_effects(self):
|
||||
# See https://github.com/jax-ml/jax/issues/26087
|
||||
def mat_vec(v):
|
||||
jax.debug.callback(lambda: print("mat_vec"), ordered=True)
|
||||
return v
|
||||
|
||||
def solve(b):
|
||||
return lax.custom_linear_solve(mat_vec, b, lambda matvec, x: matvec(x))
|
||||
|
||||
b = self.rng().randn(24)
|
||||
with jtu.capture_stdout() as output:
|
||||
expected = solve(b)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "mat_vec\n")
|
||||
with jtu.capture_stdout() as output:
|
||||
computed = jax.jit(solve)(b)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "mat_vec\n")
|
||||
self.assertAllClose(computed, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user