add experimental lax.optimization_barrier autodiff rules

Matthew Johnson 2025-03-14 22:40:41 +00:00
parent 95791fa9e4
commit dadc68b6c1
2 changed files with 19 additions and 0 deletions


@@ -8422,3 +8422,13 @@ mlir.register_lowering(optimization_barrier_p,
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
  return optimization_barrier_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher

def _opt_barrier_jvp(primals, tangents):
  tangents = [ad.instantiate_zeros(t) for t in tangents]
  return optimization_barrier(primals), optimization_barrier(tangents)
ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp

def _opt_barrier_transpose(cts, *primals):
  cts = [ad.instantiate_zeros(ct) for ct in cts]
  return optimization_barrier(cts)
ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
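These rules make the barrier transparent to autodiff: the JVP rule applies the barrier to the primals and to the (instantiated) tangents, and the transpose rule pushes cotangents back through it. A minimal forward-mode sketch, not part of the commit; the function f and the values are illustrative only:

import jax
from jax import lax

# Illustrative only: forward-mode AD now flows through the barrier,
# since the new JVP rule handles optimization_barrier_p.
def f(x):
  y = lax.optimization_barrier(2. * x)
  return y * y  # (2x)**2, so the derivative is 8x

primal_out, tangent_out = jax.jvp(f, (3.,), (1.,))
# primal_out == 36.0, tangent_out == 24.0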


@@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
    x = lax.optimization_barrier((2, 3))
    self.assertEqual((2, 3), x)

  def test_optimization_barrier_autodiff(self):
    def f(x):
      y = 1. * x
      x, y = lax.optimization_barrier((x, y))
      z = 2. * x
      return y + z
    g = jax.grad(f)(5.)  # doesn't crash
    self.assertAllClose(g, 3., check_dtypes=False)


class LazyConstantTest(jtu.JaxTestCase):

  def _Check(self, make_const, expected):
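For reverse mode, a hedged sketch in the same spirit as the test above: jax.grad linearizes through the JVP rule, then applies the new transpose rule to the tangent-side barrier. The function g here is illustrative, not from the commit:

import jax
from jax import lax

def g(x):
  # The barrier passes the tuple through unchanged numerically,
  # but gradients now flow through it via the new rules.
  a, b = lax.optimization_barrier((x, 2. * x))
  return a * b  # 2*x**2, so dg/dx = 4x

print(jax.grad(g)(3.))  # expected 12.0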