From febd339742846c97c561bb7691072a0d774586e5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 17 Apr 2023 07:52:56 -0700 Subject: [PATCH] [Micro-optimization] Only log the avals and shardings if logging is enabled for that level. PiperOrigin-RevId: 524845969 --- jax/_src/dispatch.py | 3 ++- jax/_src/interpreters/pxla.py | 29 ++++++++++++++++------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6e691d855..4f592b15b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -265,7 +265,8 @@ def log_elapsed_time(fmt: str, event: Optional[str] = None): start_time = time.time() yield elapsed_time = time.time() - start_time - logger.log(log_priority, fmt.format(elapsed_time=elapsed_time)) + if logger.isEnabledFor(log_priority): + logger.log(logging.WARNING, fmt.format(elapsed_time=elapsed_time)) if event is not None: record_event_duration_secs(event, elapsed_time) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index db16fd586..3abe4a6b9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -855,10 +855,11 @@ def lower_parallel_callable( f"{replicas.jaxpr_replicas}") log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG - logger.log(log_priority, - "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)", - fun.__name__, id(fun), - shards.num_global_shards, avals, replicas.num_global_replicas) + if logger.isEnabledFor(log_priority): + logger.log(log_priority, + "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)", + fun.__name__, id(fun), + shards.num_global_shards, avals, replicas.num_global_replicas) axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) @@ -1876,10 +1877,11 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, device_assignment = da_object.device_assignment log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG - logger.log(log_priority, - "Compiling %s for with global shapes and types %s. " - "Argument mapping: %s.", - fun_name, global_in_avals, in_shardings) + if logger.isEnabledFor(log_priority): + logger.log(log_priority, + "Compiling %s for with global shapes and types %s. " + "Argument mapping: %s.", + fun_name, global_in_avals, in_shardings) # Look at the number of replcas present in the jaxpr. In # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is @@ -2151,11 +2153,12 @@ def lower_mesh_computation( global_axis_sizes = mesh.shape log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG - logger.log(log_priority, - "Compiling %s for %s mesh with global shapes and types %s. " - "Argument mapping: %s.", - fun_name, tuple(global_axis_sizes.items()), global_in_avals, - in_shardings) + if logger.isEnabledFor(log_priority): + logger.log(log_priority, + "Compiling %s for %s mesh with global shapes and types %s. " + "Argument mapping: %s.", + fun_name, tuple(global_axis_sizes.items()), global_in_avals, + in_shardings) # 1. Trace to jaxpr and preprocess/verify it if spmd_lowering: