Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00
[Try again] For nested pjits, cache the generation of StableHLO when the cache key matches. This should help improve lowering time.

Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
Commit 85af862efd (parent 7eddc76a03)
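The diff below adds a memoized lowering path for nested pjit calls: the StableHLO for a nested jaxpr is generated once per cache key and reused on later hits. As a rough, standalone sketch of that pattern (the names below are made up for illustration, not code from this commit):

# Hypothetical sketch of the memoization pattern; not the actual JAX code.
_lowering_cache = {}  # maps a hashable key -> previously lowered function

def cached_lower(name, jaxpr, effects, device_assignment,
                 in_shardings, out_shardings, lower_fn):
  """Lower `jaxpr` once per distinct key and reuse the result afterwards."""
  key = (name, jaxpr, effects, device_assignment, in_shardings, out_shardings)
  func = _lowering_cache.get(key)
  if func is None:
    func = lower_fn(name, jaxpr, effects)  # the expensive StableHLO emission
    _lowering_cache[key] = func
  return func

In the actual change the cache lives on the module context (mod_ctx.cached_primitive_lowerings), and the shardings are wrapped in pxla.SemanticallyEqualShardings so that distinct sharding objects that are semantically equal can still hit the same cache entry.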
@@ -1361,6 +1361,36 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_):
 pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
 
 
+def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
+                                    out_shardings, api_name):
+  mod_ctx = ctx.module_context
+  axis_ctx = ctx.module_context.axis_context
+  da = None
+  if isinstance(axis_ctx, sharding_impls.ShardingContext):
+    da = tuple(axis_ctx.device_assignment)
+  elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
+    da = axis_ctx.mesh._flat_devices_tuple
+  key = (pjit_p, name, jaxpr, effects, da,
+         pxla.SemanticallyEqualShardings(in_shardings),
+         pxla.SemanticallyEqualShardings(out_shardings), api_name)
+
+  func = mod_ctx.cached_primitive_lowerings.get(key, None)
+  if func is None:
+    arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim)
+                     for aval, i in zip(ctx.avals_in, in_shardings)]
+    result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim)
+                        for aval, o in zip(ctx.avals_out, out_shardings)]
+    # TODO(b/228598865): inlined calls cannot have shardings set directly on the
+    # inputs or outputs because they are lost during MLIR->HLO conversion.
+    # using_sharding_annotation=False means we add an identity operation instead.
+    func = mlir.lower_jaxpr_to_fun(
+        mod_ctx, name, jaxpr, effects, arg_shardings=arg_shardings,
+        result_shardings=result_shardings, use_sharding_annotations=False,
+        api_name=api_name)
+    mod_ctx.cached_primitive_lowerings[key] = func
+  return func
+
+
 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    keep_unused, inline):
@@ -1369,20 +1399,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
   output_types = [mlir.token_type()] * len(effects) + output_types
   flat_output_types = flatten(output_types)
 
-  arg_shardings = [None if is_unspecified(i) else
-                   i._to_xla_hlo_sharding(aval.ndim)
-                   for aval, i in zip(ctx.avals_in, in_shardings)]
-  result_shardings = [None if is_unspecified(o) else
-                      o._to_xla_hlo_sharding(aval.ndim)
-                      for aval, o in zip(ctx.avals_out, out_shardings)]
+  func = _pjit_cached_lower_jaxpr_to_fun(
+      ctx, name, jaxpr, tuple(effects), in_shardings,
+      out_shardings, api_name=('jit' if resource_env is None else 'pjit'))
 
-  # TODO(b/228598865): inlined calls cannot have shardings set directly on the
-  # inputs or outputs because they are lost during MLIR->HLO conversion.
-  # using_sharding_annotation=False means we add an identity operation instead.
-  func = mlir.lower_jaxpr_to_fun(
-      ctx.module_context, name, jaxpr, effects, arg_shardings=arg_shardings,
-      result_shardings=result_shardings, use_sharding_annotations=False,
-      api_name=('jit' if resource_env is None else 'pjit'))
   tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
   args = (*ctx.dim_var_values, *tokens_in, *args)
   call = func_dialect.CallOp(flat_output_types,
@@ -272,6 +272,27 @@ def count_jit_and_pmap_compiles():
   finally:
     mlir.lower_jaxpr_to_module = mlir_lower
 
+
+@contextmanager
+def count_subjaxpr_to_mhlo_conversion(fun_name: str):
+  # No need to clear any caches since we generally jit and pmap fresh callables
+  # in tests.
+
+  mlir_lower = mlir.lower_jaxpr_to_fun
+  count = [0]
+
+  def mlir_lower_and_count(ctx, name, *args, **kwargs):
+    if name == fun_name:
+      count[0] += 1
+    return mlir_lower(ctx, name, *args, **kwargs)
+
+  mlir.lower_jaxpr_to_fun = mlir_lower_and_count
+  try:
+    yield count
+  finally:
+    mlir.lower_jaxpr_to_fun = mlir_lower
+
+
 @contextmanager
 def assert_num_jit_and_pmap_compilations(times):
   with count_jit_and_pmap_compiles() as count:
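The helper above temporarily monkey-patches mlir.lower_jaxpr_to_fun to count how often a sub-jaxpr with a given name is lowered, restoring the original afterwards. The same technique in a minimal, self-contained form (illustrative names only, not part of this diff):

# Generic sketch of a temporary monkey-patch call counter.
import math
from contextlib import contextmanager

@contextmanager
def count_calls(module, attr):
  """Temporarily replace `module.attr` with a wrapper that counts calls."""
  original = getattr(module, attr)
  count = [0]

  def wrapper(*args, **kwargs):
    count[0] += 1
    return original(*args, **kwargs)

  setattr(module, attr, wrapper)
  try:
    yield count
  finally:
    setattr(module, attr, original)  # always restore the original

# Example: count how many times math.sqrt runs inside the block.
with count_calls(math, 'sqrt') as n:
  math.sqrt(4.0)
  math.sqrt(9.0)
assert n[0] == 2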
@@ -3522,6 +3522,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(out2.device(), jax.devices()[0])
     self.assertArraysEqual(out2, np_inp)
 
+  def test_jit_submhlo_cached(self):
+    @jax.jit
+    def nest(x):
+      return x * 2
+
+    @jax.jit
+    def top(x):
+      y = nest(x)
+      z = nest(y)
+      a = nest(z)
+      b = nest(a)
+      return b
+
+    with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count:
+      top(jnp.arange(8))
+
+    # The count should be 1 because `nest`'s lowering to MHLO should be cached.
+    self.assertEqual(count[0], 1)
+
   def test_wsc_eager(self):
     mesh = jtu.create_global_mesh((2,), ('x',))
     np_inp = np.arange(8)
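Outside the test suite, a rough way to observe the effect is to lower a function with repeated nested jit calls and inspect the emitted StableHLO; with the cache in place the nested jaxpr should only be converted once, though the exact module text varies across JAX versions. A sketch:

# Sketch: inspect the StableHLO emitted for nested jit calls.
import jax
import jax.numpy as jnp

@jax.jit
def nest(x):
  return x * 2

@jax.jit
def top(x):
  return nest(nest(nest(x)))

# Lowered.as_text() returns the StableHLO module as a string.
print(top.lower(jnp.arange(8)).as_text())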