[Try again] For nested pjits, cache the generated StableHLO and reuse it when the cache key matches. This should help improve lowering time.

Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
This commit is contained in:
Yash Katariya 2023-10-30 15:27:17 -07:00 committed by jax authors
parent 7eddc76a03
commit 85af862efd
3 changed files with 73 additions and 13 deletions

View File

@ -1361,6 +1361,36 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_):
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
                                    out_shardings, api_name):
  """Lower `jaxpr` to a function, reusing an earlier lowering when possible.

  Lowered functions are memoized in
  `ctx.module_context.cached_primitive_lowerings`. The cache key combines
  `pjit_p`, `name`, `jaxpr`, `effects`, the devices taken from the axis
  context (or None when the axis context is neither a `ShardingContext` nor
  an `SPMDAxisContext`), the input and output shardings wrapped in
  `pxla.SemanticallyEqualShardings`, and `api_name`.

  Returns:
    The cached function for the key, or the result of a fresh
    `mlir.lower_jaxpr_to_fun` call, which is stored in the cache before
    being returned.
  """
  module_ctx = ctx.module_context
  axis_context = module_ctx.axis_context
  if isinstance(axis_context, sharding_impls.ShardingContext):
    devices = tuple(axis_context.device_assignment)
  elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
    devices = axis_context.mesh._flat_devices_tuple
  else:
    devices = None

  cache_key = (pjit_p, name, jaxpr, effects, devices,
               pxla.SemanticallyEqualShardings(in_shardings),
               pxla.SemanticallyEqualShardings(out_shardings), api_name)

  cached = module_ctx.cached_primitive_lowerings.get(cache_key)
  if cached is not None:
    return cached

  def _hlo_shardings(avals, shardings):
    # Unspecified shardings are passed to the lowering as None.
    return [None if is_unspecified(s) else s._to_xla_hlo_sharding(aval.ndim)
            for aval, s in zip(avals, shardings)]

  arg_shardings = _hlo_shardings(ctx.avals_in, in_shardings)
  result_shardings = _hlo_shardings(ctx.avals_out, out_shardings)

  # TODO(b/228598865): inlined calls cannot have shardings set directly on the
  # inputs or outputs because they are lost during MLIR->HLO conversion.
  # use_sharding_annotations=False means we add an identity operation instead.
  lowered = mlir.lower_jaxpr_to_fun(
      module_ctx, name, jaxpr, effects,
      arg_shardings=arg_shardings,
      result_shardings=result_shardings,
      use_sharding_annotations=False,
      api_name=api_name)
  module_ctx.cached_primitive_lowerings[cache_key] = lowered
  return lowered
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
keep_unused, inline):
@ -1369,20 +1399,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
output_types = [mlir.token_type()] * len(effects) + output_types
flat_output_types = flatten(output_types)
arg_shardings = [None if is_unspecified(i) else
i._to_xla_hlo_sharding(aval.ndim)
for aval, i in zip(ctx.avals_in, in_shardings)]
result_shardings = [None if is_unspecified(o) else
o._to_xla_hlo_sharding(aval.ndim)
for aval, o in zip(ctx.avals_out, out_shardings)]
func = _pjit_cached_lower_jaxpr_to_fun(
ctx, name, jaxpr, tuple(effects), in_shardings,
out_shardings, api_name=('jit' if resource_env is None else 'pjit'))
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
func = mlir.lower_jaxpr_to_fun(
ctx.module_context, name, jaxpr, effects, arg_shardings=arg_shardings,
result_shardings=result_shardings, use_sharding_annotations=False,
api_name=('jit' if resource_env is None else 'pjit'))
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
args = (*ctx.dim_var_values, *tokens_in, *args)
call = func_dialect.CallOp(flat_output_types,

View File

@ -272,6 +272,27 @@ def count_jit_and_pmap_compiles():
finally:
mlir.lower_jaxpr_to_module = mlir_lower
@contextmanager
def count_subjaxpr_to_mhlo_conversion(fun_name: str):
  """Count `mlir.lower_jaxpr_to_fun` calls whose `name` equals `fun_name`.

  While the context is active, `mlir.lower_jaxpr_to_fun` is replaced by a
  wrapper that increments a counter on each matching call and then delegates
  to the original function. The original is restored on exit, even if the
  body raises.

  Yields:
    A one-element list; element 0 holds the running count.
  """
  # Caches are intentionally not cleared: tests generally jit and pmap fresh
  # callables.
  original_lower = mlir.lower_jaxpr_to_fun
  hits = [0]

  def counting_lower(ctx, name, *args, **kwargs):
    if name == fun_name:
      hits[0] += 1
    return original_lower(ctx, name, *args, **kwargs)

  mlir.lower_jaxpr_to_fun = counting_lower
  try:
    yield hits
  finally:
    mlir.lower_jaxpr_to_fun = original_lower
@contextmanager
def assert_num_jit_and_pmap_compilations(times):
with count_jit_and_pmap_compiles() as count:

View File

@ -3522,6 +3522,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out2.device(), jax.devices()[0])
self.assertArraysEqual(out2, np_inp)
def test_jit_submhlo_cached(self):
  """Four nested calls to one jitted function lower that function only once."""
  @jax.jit
  def nest(x):
    return x * 2

  @jax.jit
  def top(x):
    # Chain four applications of `nest`, each feeding the next.
    out = x
    for _ in range(4):
      out = nest(out)
    return out

  with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count:
    top(jnp.arange(8))

  # `nest` is called four times inside `top`, but its lowering should be
  # served from the cache after the first call, so only one conversion is
  # expected.
  self.assertEqual(count[0], 1)
def test_wsc_eager(self):
mesh = jtu.create_global_mesh((2,), ('x',))
np_inp = np.arange(8)