Pass the jaxpr from pjit since there is no need to trace it again in lower_sharding_computation. It also helps in preserving debug_info that already exists on the jaxpr to surface it in MHLO eventually.

PiperOrigin-RevId: 513268085
This commit is contained in:
Yash Katariya 2023-03-01 10:04:59 -08:00 committed by jax authors
parent ed491b3056
commit 1ee750e795
2 changed files with 38 additions and 30 deletions

View File

@ -2867,7 +2867,7 @@ def _get_and_check_device_assignment(
@profiler.annotate_function
def lower_sharding_computation(
fun: lu.WrappedFun,
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]],
@ -2889,11 +2889,19 @@ def lower_sharding_computation(
# 1. Trace to jaxpr and preprocess/verify it
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
fun_or_jaxpr, global_in_avals,
debug_info=pe.debug_info_final(fun_or_jaxpr, api_name))
else:
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
jaxpr = fun_or_jaxpr.jaxpr
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
kept_outputs = [True] * len(global_out_avals)
if _is_unspecified(out_shardings):
@ -2927,10 +2935,9 @@ def lower_sharding_computation(
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for with global shapes and types %s. "
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
global_in_avals, in_shardings)
fun_name, global_in_avals, in_shardings)
if keep_unused:
kept_var_idx = set(range(len(global_in_avals)))
@ -3089,7 +3096,7 @@ def lower_sharding_computation(
@profiler.annotate_function
def lower_mesh_computation(
fun: lu.WrappedFun,
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
mesh: Mesh,
@ -3114,10 +3121,9 @@ def lower_mesh_computation(
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
"Compiling %s for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
tuple(global_axis_sizes.items()), global_in_avals,
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
in_shardings)
# 1. Trace to jaxpr and preprocess/verify it
@ -3134,10 +3140,11 @@ def lower_mesh_computation(
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_shardings)
assert not auto_spmd_lowering
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
# is why `.spec` can be accessed.
fun = tiling_transform(
fun, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
fun_or_jaxpr = tiling_transform(
fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
[get_array_mapping(o.spec) for o in out_shardings]) # type: ignore
in_jaxpr_avals = global_in_avals
else:
@ -3148,11 +3155,20 @@ def lower_mesh_computation(
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
for aval, i in safe_zip(global_in_avals, in_shardings)]
in_jaxpr_avals = in_tiled_avals
with core.extend_axis_env_nd(mesh.shape.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(
f"Finished tracing + transforming {name_stack} in "
"{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun_or_jaxpr, in_jaxpr_avals)
else:
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
jaxpr = fun_or_jaxpr.jaxpr
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals

View File

@ -1393,10 +1393,6 @@ def _pjit_lower_cached(
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
if resource_env is not None:
mesh = resource_env.physical_mesh
api_name = 'pjit'
@ -1427,18 +1423,14 @@ def _pjit_lower_cached(
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if (any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')):
if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
return pxla.lower_mesh_computation(
fun, api_name, name, mesh,
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
else:
# Pass `in_is_global` here because this path is taken by both host local
# avals and global avals.
# TODO(yashkatariya): Don't set committed to True always. Infer that from
# the arguments just like dispatch.py in `sharded_lowering`.
return pxla.lower_sharding_computation(
fun, api_name, name, in_shardings, out_shardings, donated_invars,
jaxpr, api_name, name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
always_lower=always_lower,
devices_from_context=(