From 85af862efdd6f3ea2e4bf0dfbe2d9e3d2135ab90 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 30 Oct 2023 15:27:17 -0700 Subject: [PATCH] [Try again] For nested pjit's cache the generation of StableHLO if it satisfies the key. This should help in improving the lowering time. Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035 PiperOrigin-RevId: 577974380 --- jax/_src/pjit.py | 46 +++++++++++++++++++++++++++++++------------ jax/_src/test_util.py | 21 ++++++++++++++++++++ tests/pjit_test.py | 19 ++++++++++++++++++ 3 files changed, 73 insertions(+), 13 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6e6544533..b80249dce 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1361,6 +1361,36 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_): pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) +def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, + out_shardings, api_name): + mod_ctx = ctx.module_context + axis_ctx = ctx.module_context.axis_context + da = None + if isinstance(axis_ctx, sharding_impls.ShardingContext): + da = tuple(axis_ctx.device_assignment) + elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + da = axis_ctx.mesh._flat_devices_tuple + key = (pjit_p, name, jaxpr, effects, da, + pxla.SemanticallyEqualShardings(in_shardings), + pxla.SemanticallyEqualShardings(out_shardings), api_name) + + func = mod_ctx.cached_primitive_lowerings.get(key, None) + if func is None: + arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim) + for aval, i in zip(ctx.avals_in, in_shardings)] + result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim) + for aval, o in zip(ctx.avals_out, out_shardings)] + # TODO(b/228598865): inlined calls cannot have shardings set directly on the + # inputs or outputs because they are lost during MLIR->HLO conversion. + # use_sharding_annotations=False means we add an identity operation instead. 
+ func = mlir.lower_jaxpr_to_fun( + mod_ctx, name, jaxpr, effects, arg_shardings=arg_shardings, + result_shardings=result_shardings, use_sharding_annotations=False, + api_name=api_name) + mod_ctx.cached_primitive_lowerings[key] = func + return func + + def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, keep_unused, inline): @@ -1369,20 +1399,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, output_types = [mlir.token_type()] * len(effects) + output_types flat_output_types = flatten(output_types) - arg_shardings = [None if is_unspecified(i) else - i._to_xla_hlo_sharding(aval.ndim) - for aval, i in zip(ctx.avals_in, in_shardings)] - result_shardings = [None if is_unspecified(o) else - o._to_xla_hlo_sharding(aval.ndim) - for aval, o in zip(ctx.avals_out, out_shardings)] + func = _pjit_cached_lower_jaxpr_to_fun( + ctx, name, jaxpr, tuple(effects), in_shardings, + out_shardings, api_name=('jit' if resource_env is None else 'pjit')) - # TODO(b/228598865): inlined calls cannot have shardings set directly on the - # inputs or outputs because they are lost during MLIR->HLO conversion. - # using_sharding_annotation=False means we add an identity operation instead. 
- func = mlir.lower_jaxpr_to_fun( - ctx.module_context, name, jaxpr, effects, arg_shardings=arg_shardings, - result_shardings=result_shardings, use_sharding_annotations=False, - api_name=('jit' if resource_env is None else 'pjit')) tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) call = func_dialect.CallOp(flat_output_types, diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 988a1dc88..1eaa4466e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -272,6 +272,27 @@ def count_jit_and_pmap_compiles(): finally: mlir.lower_jaxpr_to_module = mlir_lower + +@contextmanager +def count_subjaxpr_to_mhlo_conversion(fun_name: str): + # No need to clear any caches since we generally jit and pmap fresh callables + # in tests. + + mlir_lower = mlir.lower_jaxpr_to_fun + count = [0] + + def mlir_lower_and_count(ctx, name, *args, **kwargs): + if name == fun_name: + count[0] += 1 + return mlir_lower(ctx, name, *args, **kwargs) + + mlir.lower_jaxpr_to_fun = mlir_lower_and_count + try: + yield count + finally: + mlir.lower_jaxpr_to_fun = mlir_lower + + @contextmanager def assert_num_jit_and_pmap_compilations(times): with count_jit_and_pmap_compiles() as count: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6f318d080..ee2f73ab3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3522,6 +3522,25 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertEqual(out2.device(), jax.devices()[0]) self.assertArraysEqual(out2, np_inp) + def test_jit_submhlo_cached(self): + @jax.jit + def nest(x): + return x * 2 + + @jax.jit + def top(x): + y = nest(x) + z = nest(y) + a = nest(z) + b = nest(a) + return b + + with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count: + top(jnp.arange(8)) + + # The count should be 1 because `nest`'s lowering to MHLO should be cached. + self.assertEqual(count[0], 1) + def test_wsc_eager(self): mesh = jtu.create_global_mesh((2,), ('x',)) np_inp = np.arange(8)