mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
add experimental lax.optimization_barrier autodiff rules
This commit is contained in:
parent
95791fa9e4
commit
dadc68b6c1
@ -8422,3 +8422,13 @@ mlir.register_lowering(optimization_barrier_p,
|
||||
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
|
||||
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
|
||||
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
|
||||
|
||||
def _opt_barrier_jvp(primals, tangents):
|
||||
tangents = [ad.instantiate_zeros(t) for t in tangents]
|
||||
return optimization_barrier(primals), optimization_barrier(tangents)
|
||||
ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
|
||||
|
||||
def _opt_barrier_transpose(cts, *primals):
|
||||
cts = [ad.instantiate_zeros(ct) for ct in cts]
|
||||
return optimization_barrier(cts)
|
||||
ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
|
||||
|
@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
|
||||
x = lax.optimization_barrier((2, 3))
|
||||
self.assertEqual((2, 3), x)
|
||||
|
||||
def test_optimization_barrier_autodiff(self):
|
||||
def f(x):
|
||||
y = 1. * x
|
||||
x, y = lax.optimization_barrier((x, y))
|
||||
z = 2. * x
|
||||
return y + z
|
||||
g = jax.grad(f)(5.) # doesn't crash
|
||||
self.assertAllClose(g, 3., check_dtypes=False)
|
||||
|
||||
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
|
Loading…
x
Reference in New Issue
Block a user