Log the time it takes to lower from jaxpr to StableHLO

PiperOrigin-RevId: 532115098
This commit is contained in:
Yash Katariya 2023-05-15 08:07:31 -07:00 committed by jax authors
parent 843106b73c
commit 8e1ad734bc
2 changed files with 62 additions and 48 deletions

View File

@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
FLAGS = flags.FLAGS
@ -258,6 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
# TODO(yashkatariya): This API takes in a string, which means that the string
# is created even if it is not going to be logged.
@contextlib.contextmanager
def log_elapsed_time(fmt: str, event: Optional[str] = None):
if _on_exit:

View File

@ -876,21 +876,24 @@ def lower_parallel_callable(
raise ValueError("Ordered effects not supported in `pmap`.")
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
sharding_impls.ReplicaAxisContext(axis_env),
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=replicas.num_global_replicas)
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
sharding_impls.ReplicaAxisContext(axis_env),
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=replicas.num_global_replicas)
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
shards=shards, tuple_args=tuple_args,
unordered_effects=unordered_effects,
@ -1916,23 +1919,27 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
# Optionally, override the lowering platform
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
num_partitions=num_partitions)
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
# Optionally, override the lowering platform
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
num_partitions=num_partitions)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@ -2242,22 +2249,26 @@ def lower_mesh_computation(
closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(
closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_partitions,
result_shardings=out_partitions,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=num_replicas,
num_partitions=num_partitions)
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_partitions,
result_shardings=out_partitions,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=num_replicas,
num_partitions=num_partitions)
return MeshComputation(
str(name_stack),
lowering_result.module,