[jax] Fix jax_export issue with static args.

PiperOrigin-RevId: 628337221
This commit is contained in:
John QiangZhang 2024-04-26 02:11:38 -07:00 committed by jax authors
parent c176201386
commit 0b343b9ac1
2 changed files with 29 additions and 1 deletions

View File

@ -414,6 +414,9 @@ def export(fun_jax: Callable,
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may has no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
if _shape_poly.is_symbolic_dim(d):
if symbolic_scope is None:

View File

@ -213,6 +213,32 @@ class JaxExportTest(jtu.JaxTestCase):
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x), f1(x))
def test_jit_static_arg(self):
with self.subTest("static_argnames"):
@functools.partial(jax.jit, static_argnames=["c"])
def f(x, *, c):
return c * jnp.sin(x)
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x, c=0.1)
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x, c=0.1), f1(x))
with self.subTest("static_argnums"):
@functools.partial(jax.jit, static_argnums=[1])
def g(x, c):
return c * jnp.sin(x)
x = np.arange(4, dtype=np.float32)
exp_g = get_exported(g)(x, 0.1)
g1 = export.call_exported(exp_g)
self.assertAllClose(g(x, 0.1), g1(x))
def test_call_exported_lambda(self):
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
f = lambda x: jnp.sin(x)
@ -495,7 +521,6 @@ class JaxExportTest(jtu.JaxTestCase):
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
get_exported(f)(x_poly_spec, y_poly_spec)
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)