mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove static_argnums from AOT invocation.
Static args are not needed during invoking an AOT computation. PiperOrigin-RevId: 486698420
This commit is contained in:
parent
da519f3b2c
commit
b127b70e30
@ -172,9 +172,8 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
|
||||
return wraps(fun)(cpp_pjit_f)
|
||||
|
||||
class _CppPjitAotCall:
|
||||
def __init__(self, fun: Callable, static_argnums: Any):
|
||||
def __init__(self, fun: Callable):
|
||||
self._fun = fun
|
||||
self._static_argnums = static_argnums
|
||||
|
||||
def __call__(self, params: CompiledCallParams):
|
||||
|
||||
@ -206,8 +205,7 @@ class _CppPjitAotCall:
|
||||
|
||||
return outs, fastpath_data
|
||||
|
||||
self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss,
|
||||
self._static_argnums)
|
||||
self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss, [])
|
||||
return self._cpp_aot_pjit_f
|
||||
|
||||
|
||||
@ -472,7 +470,9 @@ def pjit(fun: Callable,
|
||||
in_is_global, always_lower=True)
|
||||
|
||||
if FLAGS.experimental_cpp_pjit and xc._version >= 96:
|
||||
create_cpp_call = _CppPjitAotCall(fun, static_argnums)
|
||||
# This is only used for execution of the compiled object. It is not used
|
||||
# for lowering.
|
||||
create_cpp_call = _CppPjitAotCall(fun)
|
||||
else:
|
||||
create_cpp_call = None
|
||||
|
||||
|
@ -2627,6 +2627,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = multihost_utils.global_array_to_host_local_array(arr, mesh, pspec)
|
||||
self.assertEqual(id(arr), id(out))
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jax_array(True)
|
||||
def testLowerCompileWithStaticArguments(self):
|
||||
@partial(pjit,
|
||||
in_axis_resources=P(('x', 'y'),),
|
||||
out_axis_resources=P(('x', 'y'),), static_argnums=0)
|
||||
def f(c, x):
|
||||
return x if c == 0 else x + 1
|
||||
|
||||
shape = (8, 8)
|
||||
x = jnp.arange(np.prod(shape)).reshape(shape)
|
||||
exe = f.lower(1, x).compile()
|
||||
|
||||
self.assertAllClose(exe(x), x + 1, check_dtypes=False)
|
||||
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user