mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Share lowering code between jit and aot jit path
PiperOrigin-RevId: 622487044
This commit is contained in:
parent
e8b86cd81d
commit
3b5980fd73
@ -487,20 +487,9 @@ def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donated_invars, arg_names, ()) = _infer_params(new_jit_info, args, kwargs)
|
||||
resource_env = params['resource_env']
|
||||
mesh = None if resource_env is None else resource_env.physical_mesh
|
||||
try:
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args_flat, params['in_shardings'], params['out_shardings'], mesh)
|
||||
in_layouts = _resolve_in_layouts(
|
||||
args_flat, params['in_layouts'], in_shardings,
|
||||
params['jaxpr'].in_avals)
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
in_layouts, params['out_layouts'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
params['keep_unused'], params['inline'],
|
||||
lowering_parameters=lowering_parameters)
|
||||
lowering = _resolve_and_lower(
|
||||
args_flat, **params, lowering_parameters=lowering_parameters)
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if params['resource_env'] is None else 'pjit'
|
||||
@ -1490,20 +1479,33 @@ def _resolve_in_shardings(
|
||||
return tuple(resolved_in_shardings)
|
||||
|
||||
|
||||
def _resolve_and_lower(
|
||||
args, jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
|
||||
lowering_parameters):
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
|
||||
jaxpr.in_avals)
|
||||
lowered = _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline,
|
||||
lowering_parameters=lowering_parameters)
|
||||
return lowered
|
||||
|
||||
|
||||
def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals)
|
||||
compiled = _resolve_and_lower(
|
||||
args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings,
|
||||
in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
|
||||
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline,
|
||||
lowering_parameters=mlir.LoweringParameters()).compile()
|
||||
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
|
||||
# This check is expensive so only do it if enable_checks is on.
|
||||
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
||||
|
Loading…
x
Reference in New Issue
Block a user