[export] Fix poly shape check for vjp function with integer valued, polymorphic output.

PiperOrigin-RevId: 650990009
This commit is contained in:
Tom Ward 2024-07-10 06:11:40 -07:00 committed by jax authors
parent ebfbd8ac0c
commit 33bd2925f0
2 changed files with 9 additions and 3 deletions

View File

@ -1142,7 +1142,7 @@ def call(exported: Exported) -> Callable[..., jax.Array]:
def fix_float0_ct(ct_res, expected_aval):
if expected_aval.dtype != dtypes.float0:
return ct_res
return ad_util.zeros_like_aval(expected_aval)
return ad_util.zeros_like_jaxval(ct_res)
ct_res_fixed = map(fix_float0_ct,
ct_res_flat, exp_vjp.in_avals[len(args_flat):])

View File

@ -462,7 +462,9 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(jax.grad(jax.grad(jax.grad(f)))(x),
jax.grad(jax.grad(jax.grad(f1)))(x))
def test_grad_int(self):
@jtu.parameterized_filterable(
kwargs=[dict(poly_shape=True), dict(poly_shape=False)])
def test_grad_int(self, poly_shape):
def f(xi, xf):
return (2 * xi.T, xf.T * xf.T)
@ -480,7 +482,11 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(res, (xi_ct, xf_ct))
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf)
if poly_shape:
args = export.symbolic_args_specs([xi, xf], shapes_specs=["2, a", "a, 4"])
else:
args = (xi, xf)
exp = get_exported(jax.jit(f), vjp_order=2)(*args)
fr = exp.call
res = fr(xi, xf)