From 8e1ad734bc685cea7f31e6265fe70c1b51c35669 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 15 May 2023 08:07:31 -0700
Subject: [PATCH] Log the time it takes to lower from jaxpr to stableHLO

PiperOrigin-RevId: 532115098
---
 jax/_src/dispatch.py          |   3 +
 jax/_src/interpreters/pxla.py | 107 +++++++++++++++++++---------------
 2 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index e5f1ea072..4846d59f0 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
 
 
 JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
+JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
 BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
 
 FLAGS = flags.FLAGS
@@ -258,6 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
   return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
 
 
+# TODO(yashkatariya): This API takes in a string which means that string is
+# created even if it is not going to be logged.
 @contextlib.contextmanager
 def log_elapsed_time(fmt: str, event: Optional[str] = None):
   if _on_exit:
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index dd06b7068..200e95123 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -876,21 +876,24 @@ def lower_parallel_callable(
     raise ValueError("Ordered effects not supported in `pmap`.")
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
-  lowering_result = mlir.lower_jaxpr_to_module(
-      module_name,
-      closed_jaxpr,
-      ordered_effects,
-      backend,
-      lowering_platform or backend.platform,
-      sharding_impls.ReplicaAxisContext(axis_env),
-      name_stack,
-      donated_invars,
-      replicated_args=replicated_args,
-      arg_shardings=None,
-      result_shardings=None,
-      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
-      num_replicas=replicas.num_global_replicas)
+  with dispatch.log_elapsed_time(
+      f"Finished jaxpr to MLIR module conversion {name_stack} "
+      "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
+    lowering_result = mlir.lower_jaxpr_to_module(
+        module_name,
+        closed_jaxpr,
+        ordered_effects,
+        backend,
+        lowering_platform or backend.platform,
+        sharding_impls.ReplicaAxisContext(axis_env),
+        name_stack,
+        donated_invars,
+        replicated_args=replicated_args,
+        arg_shardings=None,
+        result_shardings=None,
+        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+        num_replicas=replicas.num_global_replicas)
   return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
@@ -1916,23 +1919,27 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             in closed_jaxpr.effects):
     raise ValueError("Ordered effects are not supported for more than 1 device.")
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-  lowering_result = mlir.lower_jaxpr_to_module(
-      module_name,
-      closed_jaxpr,
-      ordered_effects,
-      backend,
-      # Optionally, override the lowering platform
-      lowering_platform or backend.platform,
-      axis_ctx,
-      name_stack,
-      donated_invars,
-      replicated_args=replicated_args,
-      arg_shardings=in_mlir_shardings,
-      result_shardings=out_mlir_shardings,
-      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
-      num_replicas=nreps,
-      num_partitions=num_partitions)
+
+  with dispatch.log_elapsed_time(
+      f"Finished jaxpr to MLIR module conversion {name_stack} "
+      "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
+    lowering_result = mlir.lower_jaxpr_to_module(
+        module_name,
+        closed_jaxpr,
+        ordered_effects,
+        backend,
+        # Optionally, override the lowering platform
+        lowering_platform or backend.platform,
+        axis_ctx,
+        name_stack,
+        donated_invars,
+        replicated_args=replicated_args,
+        arg_shardings=in_mlir_shardings,
+        result_shardings=out_mlir_shardings,
+        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+        num_replicas=nreps,
+        num_partitions=num_partitions)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -2242,22 +2249,26 @@ def lower_mesh_computation(
       closed_jaxpr.effects))
   ordered_effects = list(effects.ordered_effects.filter_in(
       closed_jaxpr.effects))
-  lowering_result = mlir.lower_jaxpr_to_module(
-      module_name,
-      closed_jaxpr,
-      ordered_effects,
-      backend,
-      lowering_platform or backend.platform,
-      axis_ctx,
-      name_stack,
-      donated_invars,
-      replicated_args=replicated_args,
-      arg_shardings=in_partitions,
-      result_shardings=out_partitions,
-      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
-      num_replicas=num_replicas,
-      num_partitions=num_partitions)
+  with dispatch.log_elapsed_time(
+      f"Finished jaxpr to MLIR module conversion {name_stack} "
+      "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
+    lowering_result = mlir.lower_jaxpr_to_module(
+        module_name,
+        closed_jaxpr,
+        ordered_effects,
+        backend,
+        lowering_platform or backend.platform,
+        axis_ctx,
+        name_stack,
+        donated_invars,
+        replicated_args=replicated_args,
+        arg_shardings=in_partitions,
+        result_shardings=out_partitions,
+        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+        num_replicas=num_replicas,
+        num_partitions=num_partitions)
+
   return MeshComputation(
       str(name_stack),
       lowering_result.module,
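
Note for reviewers: `log_elapsed_time` is a context manager that times its
body, fills a literal `{elapsed_time}` placeholder in the message, and
optionally records the duration under an event key (here the new
JAXPR_TO_MLIR_MODULE_EVENT). The following is a minimal, hypothetical
sketch of that pattern, not JAX's actual implementation, which also
handles the `_on_exit` path visible in the hunk above and reports the
event duration to an internal metrics backend elided here:

    # Simplified sketch of the log_elapsed_time pattern (illustrative only).
    import contextlib
    import logging
    import time
    from typing import Optional

    logger = logging.getLogger(__name__)

    @contextlib.contextmanager
    def log_elapsed_time(fmt: str, event: Optional[str] = None):
      start = time.time()
      yield
      elapsed_time = time.time() - start
      # `fmt` arrives with a literal "{elapsed_time}" placeholder still in
      # it; everything else in the message was built eagerly by the caller,
      # which is the cost the new TODO in dispatch.py calls out.
      logger.debug(fmt.format(elapsed_time=elapsed_time))
      if event is not None:
        pass  # record `elapsed_time` under `event` (metrics hook elided)

The call sites above deliberately concatenate an f-string (interpolating
`name_stack` immediately) with a plain string literal carrying the
`{elapsed_time}` placeholder, so the elapsed time is filled in only when
the context manager exits. Even so, the rest of the message is built
whether or not it is ever logged, hence the TODO added in dispatch.py.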