Disable XLA detailed logging and dumping for small computations.

This significantly reduces the amount of logging from XLA on TPU.

PiperOrigin-RevId: 565148809
This commit is contained in:
Peter Hawkins 2023-09-13 13:44:21 -07:00 committed by jax authors
parent eeb32a7d1f
commit 729752b32b
2 changed files with 36 additions and 1 deletions

View File

@ -52,6 +52,16 @@ _DUMP_IR_TO = jax_config.DEFINE_string(
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")
_COMPILER_DETAILED_LOGGING_MIN_OPS = jax_config.DEFINE_integer(
"jax_compiler_detailed_logging_min_ops",
jax_config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
help=(
'How big should a module be in MLIR operations before JAX enables '
'detailed compiler logging? The intent of this flag is to suppress '
'detailed logging for small/uninteresting computations.'
),
)
traceback_util.register_exclusion(__file__)
@ -71,6 +81,25 @@ def get_latest_profile_version() -> int:
return -1
def _walk_operations(op, k):
k -= 1
if k < 0:
return k
for region in op.regions:
for block in region:
for child_op in block:
k = _walk_operations(child_op, k)
if k < 0:
return k
return k
def use_detailed_logging(module: ir.Module) -> bool:
"""Returns 'true' if detailed logging should be enabled for 'module'."""
bound = _COMPILER_DETAILED_LOGGING_MIN_OPS.value
return _walk_operations(module.operation, bound) < 0
def get_compile_options(
num_replicas: int,
num_partitions: int,
@ -81,6 +110,7 @@ def get_compile_options(
auto_spmd_partitioning_mesh_ids: list[int] | None = None,
env_options_overrides: dict[str, str] | None = None,
fdo_profile: bytes | None = None,
detailed_logging: bool = True,
) -> xc.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
@ -101,7 +131,9 @@ def get_compile_options(
auto_spmd_partitioning search space.
env_options_overrides: dict of additional options parsed by the compiler
fdo_profile: Optional profile for feedback-directed optimization passed to
XLA.
XLA.
detailed_logging: Is this an "interesting" computation about which XLA
would be wise to log compilation information?
"""
compile_options = xc.CompileOptions()
compile_options.num_replicas = num_replicas
@ -178,6 +210,7 @@ def get_compile_options(
logger.error("get_compile_options XLA-AutoFDO profile: " +
"XLA-AutoFDO profile version is 0; this should not happen")
debug_options.xla_detailed_logging_and_dumping = detailed_logging
return compile_options

View File

@ -911,6 +911,7 @@ class UnloadedPmapExecutable:
device_assignment=device_assignment,
use_spmd_partitioning=False,
env_options_overrides=compiler_options,
detailed_logging=compiler.use_detailed_logging(hlo)
)
compile_options.parameter_is_tupled_arguments = tuple_args
@ -2502,6 +2503,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
fdo_profile=fdo_profile,
detailed_logging=compiler.use_detailed_logging(computation)
)
opts = compile_options.executable_build_options