From 0d92d31063e916fb8f7e303ad1075dec011796ac Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Thu, 25 Jul 2024 22:20:25 +0300 Subject: [PATCH] Show elapsed time in nanoseconds --- jax/_src/interpreters/pxla.py | 10 +++++----- jax/_src/pjit.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index ed341edde..6f8284ae5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -674,7 +674,7 @@ def stage_parallel_callable( fun = orig_fun with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", + "Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( fun, sharded_avals, pe.debug_info_final(fun, "pmap")) @@ -821,7 +821,7 @@ def lower_parallel_callable( unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) with dispatch.log_elapsed_time( - "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", + "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec", fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, @@ -1051,7 +1051,7 @@ class UnloadedPmapExecutable: out_shardings = _get_pmap_sharding(local_device_assignment, out_specs) with dispatch.log_elapsed_time( - "Finished XLA compilation of {fun_name} in {elapsed_time} sec", + "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): compiled = compiler.compile_or_get_cached( pci.backend, hlo, device_assignment, compile_options, @@ -1935,7 +1935,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, f"more than 1 device: {unsupported_effects}") ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) with dispatch.log_elapsed_time( - "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", + "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec", fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, @@ -2616,7 +2616,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, dev, pmap_nreps, compiler_options) with dispatch.log_elapsed_time( - "Finished XLA compilation of {fun_name} in {elapsed_time} sec", + "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec", fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 169f98433..d7173c049 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1266,7 +1266,7 @@ def _create_pjit_jaxpr( list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: del ignored_inline # just for explain_cache_miss with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec", + "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) if config.dynamic_shapes.value: