mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove the f-string evaluation during logging the elapsed time by passing in fun_name to log_elapsed_time
PiperOrigin-RevId: 532132574
This commit is contained in:
parent
15caafd937
commit
b196ad2e8c
@ -259,10 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
|
||||
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
|
||||
|
||||
|
||||
# TODO(yashkatariya): This API takes in a string which means that string is
|
||||
# created even if it is not going to be logged.
|
||||
@contextlib.contextmanager
|
||||
def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
def log_elapsed_time(fmt: str, fun_name: str, event: Optional[str] = None):
|
||||
if _on_exit:
|
||||
yield
|
||||
else:
|
||||
@ -271,7 +269,8 @@ def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(logging.WARNING, fmt.format(elapsed_time=elapsed_time))
|
||||
logger.log(logging.WARNING, fmt.format(
|
||||
fun_name=fun_name, elapsed_time=elapsed_time))
|
||||
if event is not None:
|
||||
record_event_duration_secs(event, elapsed_time)
|
||||
|
||||
|
@ -742,9 +742,9 @@ def stage_parallel_callable(
|
||||
for axis, aval in safe_zip(pci.in_axes, pci.avals))
|
||||
|
||||
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for pmap in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
|
||||
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
|
||||
@ -877,8 +877,8 @@ def lower_parallel_callable(
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -1081,8 +1081,8 @@ class UnloadedPmapExecutable:
|
||||
ordered_effects, jaxpr_debug_info)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
compiled = dispatch.compile_or_get_cached(
|
||||
pci.backend, hlo, device_assignment, compile_options,
|
||||
host_callbacks)
|
||||
@ -1820,8 +1820,8 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
|
||||
|
||||
if isinstance(fun_or_jaxpr, lu.WrappedFun):
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
|
||||
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun_or_jaxpr, global_in_avals)
|
||||
else:
|
||||
@ -1921,8 +1921,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -2191,8 +2191,8 @@ def lower_mesh_computation(
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
if isinstance(fun_or_jaxpr, lu.WrappedFun):
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished tracing + transforming {name_stack} in "
|
||||
"{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
|
||||
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun_or_jaxpr, in_jaxpr_avals)
|
||||
else:
|
||||
@ -2250,8 +2250,8 @@ def lower_mesh_computation(
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(
|
||||
closed_jaxpr.effects))
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -2556,9 +2556,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
return None, compile_options
|
||||
|
||||
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = dispatch.compile_or_get_cached(
|
||||
backend, computation, dev, compile_options, host_callbacks)
|
||||
return xla_executable, compile_options
|
||||
|
@ -652,9 +652,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
mapped_in_avals = [_delete_aval_axes(aval, in_axes, global_axis_sizes)
|
||||
for aval, in_axes in zip(in_avals, in_axes)]
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for xmap in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for xmap in {elapsed_time} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
|
||||
out_axes = out_axes_thunk()
|
||||
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
|
||||
|
@ -886,9 +886,9 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
|
||||
|
||||
@lu.cache
|
||||
def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for pjit in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
|
||||
if config.jax_dynamic_shapes:
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
|
||||
|
Loading…
x
Reference in New Issue
Block a user