mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Log the time it takes to lower from jaxpr to stableHLO
PiperOrigin-RevId: 532115098
This commit is contained in:
parent
843106b73c
commit
8e1ad734bc
@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
|
||||
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
@ -258,6 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
|
||||
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
|
||||
|
||||
|
||||
# TODO(yashkatariya): This API takes in a string which means that string is
|
||||
# created even if it is not going to be logged.
|
||||
@contextlib.contextmanager
|
||||
def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
if _on_exit:
|
||||
|
@ -876,6 +876,9 @@ def lower_parallel_callable(
|
||||
raise ValueError("Ordered effects not supported in `pmap`.")
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -1916,6 +1919,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
in closed_jaxpr.effects):
|
||||
raise ValueError("Ordered effects are not supported for more than 1 device.")
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -2242,6 +2249,9 @@ def lower_mesh_computation(
|
||||
closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(
|
||||
closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -2258,6 +2268,7 @@ def lower_mesh_computation(
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions)
|
||||
|
||||
return MeshComputation(
|
||||
str(name_stack),
|
||||
lowering_result.module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user