[Micro-optimization] Only log the avals and shardings if logging is enabled for that level.

PiperOrigin-RevId: 524845969
This commit is contained in:
Yash Katariya 2023-04-17 07:52:56 -07:00 committed by jax authors
parent 8ce19eea4f
commit febd339742
2 changed files with 18 additions and 14 deletions

View File

@ -265,7 +265,8 @@ def log_elapsed_time(fmt: str, event: Optional[str] = None):
start_time = time.time()
yield
elapsed_time = time.time() - start_time
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
if logger.isEnabledFor(log_priority):
logger.log(logging.WARNING, fmt.format(elapsed_time=elapsed_time))
if event is not None:
record_event_duration_secs(event, elapsed_time)

View File

@ -855,10 +855,11 @@ def lower_parallel_callable(
f"{replicas.jaxpr_replicas}")
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas)
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas)
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
@ -1876,10 +1877,11 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
device_assignment = da_object.device_assignment
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
# Look at the number of replcas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
@ -2151,11 +2153,12 @@ def lower_mesh_computation(
global_axis_sizes = mesh.shape
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
in_shardings)
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
in_shardings)
# 1. Trace to jaxpr and preprocess/verify it
if spmd_lowering: