mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14027 from mattjj:issue14026
PiperOrigin-RevId: 502408481
This commit is contained in:
commit
432b909ef8
@ -687,7 +687,9 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
|
||||
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
|
||||
|
||||
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
|
||||
del linear # could use for error checking, but see #14026
|
||||
index, *ops = args
|
||||
linear = [type(x) is ad.UndefinedPrimal for x in ops]
|
||||
in_avals = map(raise_to_shaped, branches[0].in_avals)
|
||||
num_res = len(ops) - sum(linear)
|
||||
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
|
||||
|
@ -2564,5 +2564,45 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
jax.grad(f)(1.) # doesn't crash
|
||||
|
||||
def test_custom_jvp_tangent_cond_transpose(self):
|
||||
# https://github.com/google/jax/issues/14026
|
||||
def mask_fun(arr, choice):
|
||||
out = (1 - choice) * arr.sum() + choice * (1 - arr.sum())
|
||||
return out
|
||||
|
||||
def switch_fun(arr, choice):
|
||||
choice = jnp.floor(choice).astype(jnp.int32)
|
||||
out = jax.lax.switch(choice, [lambda x: x.sum(), lambda x: 1 - x.sum()], arr)
|
||||
return out
|
||||
|
||||
test_arr = jnp.arange(3.)
|
||||
test_val = 0.
|
||||
|
||||
expected1 = jax.grad(mask_fun)(test_arr, test_val)
|
||||
expected2 = jax.grad(switch_fun)(test_arr, test_val)
|
||||
|
||||
def good_switchfun_jvp(primals, tangents):
|
||||
arr, choice = primals
|
||||
arr_dot, choice_dot = tangents
|
||||
return switch_fun(arr, choice), mask_fun(arr_dot, choice)
|
||||
|
||||
def bad_switchfun_jvp(primals, tangents):
|
||||
arr, choice = primals
|
||||
arr_dot, choice_dot = tangents
|
||||
return switch_fun(arr, choice), switch_fun(arr_dot, choice)
|
||||
|
||||
good_custom_switchfun = jax.custom_jvp(switch_fun)
|
||||
good_custom_switchfun.defjvp(good_switchfun_jvp)
|
||||
expected3 = jax.grad(good_custom_switchfun)(test_arr, test_val)
|
||||
|
||||
bad_custom_switchfun = jax.custom_jvp(switch_fun)
|
||||
bad_custom_switchfun.defjvp(bad_switchfun_jvp)
|
||||
actual = jax.grad(bad_custom_switchfun)(test_arr, test_val)
|
||||
|
||||
self.assertAllClose(expected1, expected2)
|
||||
self.assertAllClose(expected2, expected3)
|
||||
self.assertAllClose(expected3, actual)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user