mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[jax] Fix jax_export issue with static args.
PiperOrigin-RevId: 628337221
This commit is contained in:
parent
c176201386
commit
0b343b9ac1
@ -414,6 +414,9 @@ def export(fun_jax: Callable,
|
||||
|
||||
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
|
||||
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
|
||||
# Static args may has no `shape` attribute.
|
||||
if not hasattr(aval, "shape"):
|
||||
continue
|
||||
for d in aval.shape:
|
||||
if _shape_poly.is_symbolic_dim(d):
|
||||
if symbolic_scope is None:
|
||||
|
@ -213,6 +213,32 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_jit_static_arg(self):
|
||||
|
||||
with self.subTest("static_argnames"):
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["c"])
|
||||
def f(x, *, c):
|
||||
return c * jnp.sin(x)
|
||||
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x, c=0.1)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x, c=0.1), f1(x))
|
||||
|
||||
with self.subTest("static_argnums"):
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=[1])
|
||||
def g(x, c):
|
||||
return c * jnp.sin(x)
|
||||
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_g = get_exported(g)(x, 0.1)
|
||||
|
||||
g1 = export.call_exported(exp_g)
|
||||
self.assertAllClose(g(x, 0.1), g1(x))
|
||||
|
||||
def test_call_exported_lambda(self):
|
||||
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
|
||||
f = lambda x: jnp.sin(x)
|
||||
@ -495,7 +521,6 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
|
||||
get_exported(f)(x_poly_spec, y_poly_spec)
|
||||
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
|
Loading…
x
Reference in New Issue
Block a user