mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[export] Fix poly shape check for vjp function with integer valued, polymorphic output.
PiperOrigin-RevId: 650990009
This commit is contained in:
parent
ebfbd8ac0c
commit
33bd2925f0
@ -1142,7 +1142,7 @@ def call(exported: Exported) -> Callable[..., jax.Array]:
|
||||
def fix_float0_ct(ct_res, expected_aval):
|
||||
if expected_aval.dtype != dtypes.float0:
|
||||
return ct_res
|
||||
return ad_util.zeros_like_aval(expected_aval)
|
||||
return ad_util.zeros_like_jaxval(ct_res)
|
||||
|
||||
ct_res_fixed = map(fix_float0_ct,
|
||||
ct_res_flat, exp_vjp.in_avals[len(args_flat):])
|
||||
|
@ -462,7 +462,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jax.grad(jax.grad(jax.grad(f)))(x),
|
||||
jax.grad(jax.grad(jax.grad(f1)))(x))
|
||||
|
||||
def test_grad_int(self):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[dict(poly_shape=True), dict(poly_shape=False)])
|
||||
def test_grad_int(self, poly_shape):
|
||||
def f(xi, xf):
|
||||
return (2 * xi.T, xf.T * xf.T)
|
||||
|
||||
@ -480,7 +482,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(res, (xi_ct, xf_ct))
|
||||
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
|
||||
|
||||
exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf)
|
||||
if poly_shape:
|
||||
args = export.symbolic_args_specs([xi, xf], shapes_specs=["2, a", "a, 4"])
|
||||
else:
|
||||
args = (xi, xf)
|
||||
exp = get_exported(jax.jit(f), vjp_order=2)(*args)
|
||||
fr = exp.call
|
||||
|
||||
res = fr(xi, xf)
|
||||
|
Loading…
x
Reference in New Issue
Block a user