mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Micro-optimization] Only log the avals and shardings if logging is enabled for that level.
PiperOrigin-RevId: 524845969
This commit is contained in:
parent
8ce19eea4f
commit
febd339742
@ -265,7 +265,8 @@ def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
start_time = time.time()
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(logging.WARNING, fmt.format(elapsed_time=elapsed_time))
|
||||
if event is not None:
|
||||
record_event_duration_secs(event, elapsed_time)
|
||||
|
||||
|
@ -855,10 +855,11 @@ def lower_parallel_callable(
|
||||
f"{replicas.jaxpr_replicas}")
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
|
||||
fun.__name__, id(fun),
|
||||
shards.num_global_shards, avals, replicas.num_global_replicas)
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
|
||||
fun.__name__, id(fun),
|
||||
shards.num_global_shards, avals, replicas.num_global_replicas)
|
||||
|
||||
axis_env = sharding_impls.AxisEnv(
|
||||
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
||||
@ -1876,10 +1877,11 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
device_assignment = da_object.device_assignment
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logger.log(log_priority,
|
||||
"Compiling %s for with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
fun_name, global_in_avals, in_shardings)
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(log_priority,
|
||||
"Compiling %s for with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
fun_name, global_in_avals, in_shardings)
|
||||
|
||||
# Look at the number of replcas present in the jaxpr. In
|
||||
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
|
||||
@ -2151,11 +2153,12 @@ def lower_mesh_computation(
|
||||
global_axis_sizes = mesh.shape
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logger.log(log_priority,
|
||||
"Compiling %s for %s mesh with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
|
||||
in_shardings)
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(log_priority,
|
||||
"Compiling %s for %s mesh with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
|
||||
in_shardings)
|
||||
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
if spmd_lowering:
|
||||
|
Loading…
x
Reference in New Issue
Block a user