Show elapsed time in nanoseconds

This commit is contained in:
Ram Rachum 2024-07-25 22:20:25 +03:00
parent f17d0f382a
commit 0d92d31063
2 changed files with 6 additions and 6 deletions

View File

@ -674,7 +674,7 @@ def stage_parallel_callable(
fun = orig_fun
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
@ -821,7 +821,7 @@ def lower_parallel_callable(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
@ -1051,7 +1051,7 @@ class UnloadedPmapExecutable:
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
"Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
compiled = compiler.compile_or_get_cached(
pci.backend, hlo, device_assignment, compile_options,
@ -1935,7 +1935,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
f"more than 1 device: {unsupported_effects}")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
@ -2616,7 +2616,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
dev, pmap_nreps, compiler_options)
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
"Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = compiler.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks,

View File

@ -1266,7 +1266,7 @@ def _create_pjit_jaxpr(
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
del ignored_inline # just for explain_cache_miss
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
if config.dynamic_shapes.value: