mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Show elapsed time in nanoseconds
This commit is contained in:
parent
f17d0f382a
commit
0d92d31063
@ -674,7 +674,7 @@ def stage_parallel_callable(
|
||||
fun = orig_fun
|
||||
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
|
||||
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
|
||||
@ -821,7 +821,7 @@ def lower_parallel_callable(
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
@ -1051,7 +1051,7 @@ class UnloadedPmapExecutable:
|
||||
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
|
||||
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
compiled = compiler.compile_or_get_cached(
|
||||
pci.backend, hlo, device_assignment, compile_options,
|
||||
@ -1935,7 +1935,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
f"more than 1 device: {unsupported_effects}")
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
@ -2616,7 +2616,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
dev, pmap_nreps, compiler_options)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
|
||||
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = compiler.compile_or_get_cached(
|
||||
backend, computation, dev, compile_options, host_callbacks,
|
||||
|
@ -1266,7 +1266,7 @@ def _create_pjit_jaxpr(
|
||||
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
||||
del ignored_inline # just for explain_cache_miss
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
|
||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
|
||||
if config.dynamic_shapes.value:
|
||||
|
Loading…
x
Reference in New Issue
Block a user