Share lowering code between jit and aot jit path

PiperOrigin-RevId: 622487044
This commit is contained in:
Yash Katariya 2024-04-06 13:43:32 -07:00 committed by jax authors
parent e8b86cd81d
commit 3b5980fd73

View File

@ -487,20 +487,9 @@ def _make_jit_wrapper(jit_info: PjitInfo):
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, arg_names, ()) = _infer_params(new_jit_info, args, kwargs)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
in_layouts = _resolve_in_layouts(
args_flat, params['in_layouts'], in_shardings,
params['jaxpr'].in_avals)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
in_layouts, params['out_layouts'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'],
lowering_parameters=lowering_parameters)
lowering = _resolve_and_lower(
args_flat, **params, lowering_parameters=lowering_parameters)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
@ -1490,20 +1479,33 @@ def _resolve_in_shardings(
return tuple(resolved_in_shardings)
def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
lowering_parameters):
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
jaxpr.in_avals)
lowered = _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline,
lowering_parameters=lowering_parameters)
return lowered
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
global _most_recent_pjit_call_executable
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals)
compiled = _resolve_and_lower(
args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings,
in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline,
lowering_parameters=mlir.LoweringParameters()).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.enable_checks.value: