diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 22d703945..76b3fb9ec 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -8422,3 +8422,13 @@ mlir.register_lowering(optimization_barrier_p,
 def _optimization_barrier_batcher(batched_args, batch_dims, **params):
   return optimization_barrier_p.bind(*batched_args, **params), batch_dims
 batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
+
+def _opt_barrier_jvp(primals, tangents):
+  tangents = [ad.instantiate_zeros(t) for t in tangents]
+  return optimization_barrier(primals), optimization_barrier(tangents)
+ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
+
+def _opt_barrier_transpose(cts, *primals):
+  cts = [ad.instantiate_zeros(ct) for ct in cts]
+  return optimization_barrier(cts)
+ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 4b67819be..8764caeb2 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
     x = lax.optimization_barrier((2, 3))
     self.assertEqual((2, 3), x)
 
+  def test_optimization_barrier_autodiff(self):
+    def f(x):
+      y = 1. * x
+      x, y = lax.optimization_barrier((x, y))
+      z = 2. * x
+      return y + z
+    g = jax.grad(f)(5.)  # doesn't crash
+    self.assertAllClose(g, 3., check_dtypes=False)
+
 class LazyConstantTest(jtu.JaxTestCase):
 
   def _Check(self, make_const, expected):
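
For reference, a minimal usage sketch of what these rules enable, mirroring the new test: differentiation treats lax.optimization_barrier as the identity on its operands, while the barrier still blocks XLA from reordering or fusing across it. The function f and the input 5. below are illustrative; the printed values are what the patched rules should produce, not outputs from an actual run.

import jax
from jax import lax

def f(x):
  y = 1. * x
  # Identity on values; prevents XLA from reordering/fusing across it.
  x, y = lax.optimization_barrier((x, y))
  z = 2. * x
  return y + z

# Reverse mode, via the new transpose rule: f(x) == 3 * x, so grad is 3.
print(jax.grad(f)(5.))           # 3.0
# Forward mode, via the new JVP rule: (primal out, tangent out).
print(jax.jvp(f, (5.,), (1.,)))  # (15.0, 3.0)

Before this patch, both calls would raise a NotImplementedError for the missing differentiation rule on optimization_barrier_p; the JVP rule instantiates symbolic-zero tangents and passes them through the barrier, and the transpose rule does the same for cotangents, so the barrier also applies to the differentiated computation rather than being silently dropped from it.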