mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Pass the jaxpr from pjit, since there is no need to trace it again in lower_sharding_computation. This also helps preserve the debug_info that already exists on the jaxpr, so that it can eventually be surfaced in MHLO.
PiperOrigin-RevId: 513268085
This commit is contained in:
parent
ed491b3056
commit
1ee750e795
@ -2867,7 +2867,7 @@ def _get_and_check_device_assignment(
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_sharding_computation(
|
||||
fun: lu.WrappedFun,
|
||||
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
|
||||
api_name: str,
|
||||
fun_name: str,
|
||||
in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]],
|
||||
@ -2889,11 +2889,19 @@ def lower_sharding_computation(
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
if isinstance(fun_or_jaxpr, lu.WrappedFun):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
|
||||
fun_or_jaxpr, global_in_avals,
|
||||
debug_info=pe.debug_info_final(fun_or_jaxpr, api_name))
|
||||
else:
|
||||
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
|
||||
jaxpr = fun_or_jaxpr.jaxpr
|
||||
global_out_avals = fun_or_jaxpr.out_avals
|
||||
consts = fun_or_jaxpr.consts
|
||||
|
||||
kept_outputs = [True] * len(global_out_avals)
|
||||
|
||||
if _is_unspecified(out_shardings):
|
||||
@ -2927,10 +2935,9 @@ def lower_sharding_computation(
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for with global shapes and types %s. "
|
||||
"Compiling %s for with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
||||
global_in_avals, in_shardings)
|
||||
fun_name, global_in_avals, in_shardings)
|
||||
|
||||
if keep_unused:
|
||||
kept_var_idx = set(range(len(global_in_avals)))
|
||||
@ -3089,7 +3096,7 @@ def lower_sharding_computation(
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_mesh_computation(
|
||||
fun: lu.WrappedFun,
|
||||
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
|
||||
api_name: str,
|
||||
fun_name: str,
|
||||
mesh: Mesh,
|
||||
@ -3114,10 +3121,9 @@ def lower_mesh_computation(
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
|
||||
"Compiling %s for %s mesh with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
||||
tuple(global_axis_sizes.items()), global_in_avals,
|
||||
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
|
||||
in_shardings)
|
||||
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
@ -3134,10 +3140,11 @@ def lower_mesh_computation(
|
||||
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
|
||||
assert not callable(out_shardings)
|
||||
assert not auto_spmd_lowering
|
||||
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
|
||||
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
|
||||
# is why `.spec` can be accessed.
|
||||
fun = tiling_transform(
|
||||
fun, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
|
||||
fun_or_jaxpr = tiling_transform(
|
||||
fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
|
||||
[get_array_mapping(o.spec) for o in out_shardings]) # type: ignore
|
||||
in_jaxpr_avals = global_in_avals
|
||||
else:
|
||||
@ -3148,11 +3155,20 @@ def lower_mesh_computation(
|
||||
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
|
||||
for aval, i in safe_zip(global_in_avals, in_shardings)]
|
||||
in_jaxpr_avals = in_tiled_avals
|
||||
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
|
||||
if isinstance(fun_or_jaxpr, lu.WrappedFun):
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished tracing + transforming {name_stack} in "
|
||||
"{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun_or_jaxpr, in_jaxpr_avals)
|
||||
else:
|
||||
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
|
||||
jaxpr = fun_or_jaxpr.jaxpr
|
||||
out_jaxpr_avals = fun_or_jaxpr.out_avals
|
||||
consts = fun_or_jaxpr.consts
|
||||
|
||||
assert len(out_shardings) == len(out_jaxpr_avals)
|
||||
if spmd_lowering:
|
||||
global_out_avals = out_jaxpr_avals
|
||||
|
@ -1393,10 +1393,6 @@ def _pjit_lower_cached(
|
||||
if resource_env is not None:
|
||||
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
|
||||
|
||||
f = core.jaxpr_as_fun(jaxpr)
|
||||
f.__name__ = name
|
||||
fun = lu.wrap_init(f)
|
||||
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
api_name = 'pjit'
|
||||
@ -1427,18 +1423,14 @@ def _pjit_lower_cached(
|
||||
|
||||
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
|
||||
# because `xmap` only supports SPMDAxisContext right now.
|
||||
if (any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')):
|
||||
if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
|
||||
return pxla.lower_mesh_computation(
|
||||
fun, api_name, name, mesh,
|
||||
jaxpr, api_name, name, mesh,
|
||||
in_shardings, out_shardings, donated_invars,
|
||||
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
|
||||
else:
|
||||
# Pass `in_is_global` here because this path is taken by both host local
|
||||
# avals and global avals.
|
||||
# TODO(yashkatariya): Don't set committed to True always. Infer that from
|
||||
# the arguments just like dispatch.py in `sharded_lowering`.
|
||||
return pxla.lower_sharding_computation(
|
||||
fun, api_name, name, in_shardings, out_shardings, donated_invars,
|
||||
jaxpr, api_name, name, in_shardings, out_shardings, donated_invars,
|
||||
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
|
||||
always_lower=always_lower,
|
||||
devices_from_context=(
|
||||
|
Loading…
x
Reference in New Issue
Block a user