Remove static_argnums from AOT invocation.

Static args are not needed during invoking an AOT computation.

PiperOrigin-RevId: 486698420
This commit is contained in:
Kuangyuan Chen 2022-11-07 10:21:13 -08:00 committed by jax authors
parent da519f3b2c
commit b127b70e30
2 changed files with 21 additions and 5 deletions

View File

@ -172,9 +172,8 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
return wraps(fun)(cpp_pjit_f)
class _CppPjitAotCall:
def __init__(self, fun: Callable, static_argnums: Any):
def __init__(self, fun: Callable):
self._fun = fun
self._static_argnums = static_argnums
def __call__(self, params: CompiledCallParams):
@ -206,8 +205,7 @@ class _CppPjitAotCall:
return outs, fastpath_data
self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss,
self._static_argnums)
self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss, [])
return self._cpp_aot_pjit_f
@ -472,7 +470,9 @@ def pjit(fun: Callable,
in_is_global, always_lower=True)
if FLAGS.experimental_cpp_pjit and xc._version >= 96:
create_cpp_call = _CppPjitAotCall(fun, static_argnums)
# This is only used for execution of the compiled object. It is not used
# for lowering.
create_cpp_call = _CppPjitAotCall(fun)
else:
create_cpp_call = None

View File

@ -2627,6 +2627,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = multihost_utils.global_array_to_host_local_array(arr, mesh, pspec)
self.assertEqual(id(arr), id(out))
@jtu.with_mesh([('x', 2), ('y', 2)])
@jax_array(True)
def testLowerCompileWithStaticArguments(self):
@partial(pjit,
in_axis_resources=P(('x', 'y'),),
out_axis_resources=P(('x', 'y'),), static_argnums=0)
def f(c, x):
return x if c == 0 else x + 1
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(1, x).compile()
self.assertAllClose(exe(x), x + 1, check_dtypes=False)
class TempSharding(Sharding):