Log the time it takes to lower from jaxpr to stableHLO

PiperOrigin-RevId: 532115098
This commit is contained in:
Yash Katariya 2023-05-15 08:07:31 -07:00 committed by jax authors
parent 843106b73c
commit 8e1ad734bc
2 changed files with 62 additions and 48 deletions

View File

@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
FLAGS = flags.FLAGS
@ -258,6 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
# TODO(yashkatariya): This API takes in a string which means that string is
# created even if it is not going to be logged.
@contextlib.contextmanager
def log_elapsed_time(fmt: str, event: Optional[str] = None):
if _on_exit:

View File

@ -876,6 +876,9 @@ def lower_parallel_callable(
raise ValueError("Ordered effects not supported in `pmap`.")
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -1916,6 +1919,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -2242,6 +2249,9 @@ def lower_mesh_computation(
closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(
closed_jaxpr.effects))
with dispatch.log_elapsed_time(
f"Finished jaxpr to MLIR module conversion {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -2258,6 +2268,7 @@ def lower_mesh_computation(
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=num_replicas,
num_partitions=num_partitions)
return MeshComputation(
str(name_stack),
lowering_result.module,