Merge pull request #14027 from mattjj:issue14026

PiperOrigin-RevId: 502408481
This commit is contained in:
jax authors 2023-01-16 10:59:46 -08:00
commit 432b909ef8
2 changed files with 42 additions and 0 deletions

View File

@ -687,7 +687,9 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
del linear # could use for error checking, but see #14026
index, *ops = args
linear = [type(x) is ad.UndefinedPrimal for x in ops]
in_avals = map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in

View File

@ -2564,5 +2564,45 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.grad(f)(1.) # doesn't crash
def test_custom_jvp_tangent_cond_transpose(self):
# https://github.com/google/jax/issues/14026
def mask_fun(arr, choice):
out = (1 - choice) * arr.sum() + choice * (1 - arr.sum())
return out
def switch_fun(arr, choice):
choice = jnp.floor(choice).astype(jnp.int32)
out = jax.lax.switch(choice, [lambda x: x.sum(), lambda x: 1 - x.sum()], arr)
return out
test_arr = jnp.arange(3.)
test_val = 0.
expected1 = jax.grad(mask_fun)(test_arr, test_val)
expected2 = jax.grad(switch_fun)(test_arr, test_val)
def good_switchfun_jvp(primals, tangents):
arr, choice = primals
arr_dot, choice_dot = tangents
return switch_fun(arr, choice), mask_fun(arr_dot, choice)
def bad_switchfun_jvp(primals, tangents):
arr, choice = primals
arr_dot, choice_dot = tangents
return switch_fun(arr, choice), switch_fun(arr_dot, choice)
good_custom_switchfun = jax.custom_jvp(switch_fun)
good_custom_switchfun.defjvp(good_switchfun_jvp)
expected3 = jax.grad(good_custom_switchfun)(test_arr, test_val)
bad_custom_switchfun = jax.custom_jvp(switch_fun)
bad_custom_switchfun.defjvp(bad_switchfun_jvp)
actual = jax.grad(bad_custom_switchfun)(test_arr, test_val)
self.assertAllClose(expected1, expected2)
self.assertAllClose(expected2, expected3)
self.assertAllClose(expected3, actual)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())