diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 4846d59f0..044b5c640 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -259,10 +259,8 @@ def is_single_device_sharding(sharding) -> bool: return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) -# TODO(yashkatariya): This API takes in a string which means that string is -# created even if it is not going to be logged. @contextlib.contextmanager -def log_elapsed_time(fmt: str, event: Optional[str] = None): +def log_elapsed_time(fmt: str, fun_name: str, event: Optional[str] = None): if _on_exit: yield else: @@ -271,7 +269,8 @@ def log_elapsed_time(fmt: str, event: Optional[str] = None): yield elapsed_time = time.time() - start_time if logger.isEnabledFor(log_priority): - logger.log(logging.WARNING, fmt.format(elapsed_time=elapsed_time)) + logger.log(logging.WARNING, fmt.format( + fun_name=fun_name, elapsed_time=elapsed_time)) if event is not None: record_event_duration_secs(event, elapsed_time) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 200e95123..0be652994 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -742,9 +742,9 @@ def stage_parallel_callable( for axis, aval in safe_zip(pci.in_axes, pci.avals)) with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore - with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for pmap in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): + with dispatch.log_elapsed_time( + "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", + fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( fun, sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info) @@ -877,8 +877,8 @@ def lower_parallel_callable( unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) with dispatch.log_elapsed_time( - f"Finished jaxpr to MLIR module conversion {name_stack} " - "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): + "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", + fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -1081,8 +1081,8 @@ class UnloadedPmapExecutable: ordered_effects, jaxpr_debug_info) with dispatch.log_elapsed_time( - f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec", - event=dispatch.BACKEND_COMPILE_EVENT): + "Finished XLA compilation of {fun_name} in {elapsed_time} sec", + fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): compiled = dispatch.compile_or_get_cached( pci.backend, hlo, device_assignment, compile_options, host_callbacks) @@ -1820,8 +1820,8 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name, if isinstance(fun_or_jaxpr, lu.WrappedFun): with dispatch.log_elapsed_time( - f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): + "Finished tracing + transforming {fun_name} in {elapsed_time} sec", + fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT): jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( fun_or_jaxpr, global_in_avals) else: @@ -1921,8 +1921,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) with dispatch.log_elapsed_time( - f"Finished jaxpr to MLIR module conversion {name_stack} " - "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): + "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", + fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -2191,8 +2191,8 @@ def lower_mesh_computation( with core.extend_axis_env_nd(mesh.shape.items()): if isinstance(fun_or_jaxpr, lu.WrappedFun): with dispatch.log_elapsed_time( - f"Finished tracing + transforming {name_stack} in " - "{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): + "Finished tracing + transforming {fun_name} in {elapsed_time} sec", + fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final( fun_or_jaxpr, in_jaxpr_avals) else: @@ -2250,8 +2250,8 @@ def lower_mesh_computation( ordered_effects = list(effects.ordered_effects.filter_in( closed_jaxpr.effects)) with dispatch.log_elapsed_time( - f"Finished jaxpr to MLIR module conversion {name_stack} " - "in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): + "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", + fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -2556,9 +2556,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, if hasattr(backend, "compile_replicated"): return None, compile_options - with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " - "in {elapsed_time} sec", - event=dispatch.BACKEND_COMPILE_EVENT): + with dispatch.log_elapsed_time( + "Finished XLA compilation of {fun_name} in {elapsed_time} sec", + fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = dispatch.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks) return xla_executable, compile_options diff --git a/jax/_src/maps.py b/jax/_src/maps.py index c3c8fdd9a..64bc0d4d6 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -652,9 +652,9 @@ def make_xmap_callable(fun: lu.WrappedFun, mapped_in_avals = [_delete_aval_axes(aval, in_axes, global_axis_sizes) for aval, in_axes in zip(in_avals, in_axes)] with core.extend_axis_env_nd(global_axis_sizes.items()): - with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for xmap in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): + with dispatch.log_elapsed_time( + "Finished tracing + transforming {fun_name} for xmap in {elapsed_time} sec", + fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals) out_axes = out_axes_thunk() _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8ea45f7eb..320a54132 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -886,9 +886,9 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, @lu.cache def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths): - with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for pjit in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): + with dispatch.log_elapsed_time( + "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec", + fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) if config.jax_dynamic_shapes: jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(