mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Disable XLA detailed logging and dumping for small computations.
This significantly reduces the amount of logging from XLA on TPU. PiperOrigin-RevId: 565148809
This commit is contained in:
parent
eeb32a7d1f
commit
729752b32b
@ -52,6 +52,16 @@ _DUMP_IR_TO = jax_config.DEFINE_string(
|
||||
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
||||
"will not dump IR.")
|
||||
|
||||
_COMPILER_DETAILED_LOGGING_MIN_OPS = jax_config.DEFINE_integer(
|
||||
"jax_compiler_detailed_logging_min_ops",
|
||||
jax_config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
|
||||
help=(
|
||||
'How big should a module be in MLIR operations before JAX enables '
|
||||
'detailed compiler logging? The intent of this flag is to suppress '
|
||||
'detailed logging for small/uninteresting computations.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
@ -71,6 +81,25 @@ def get_latest_profile_version() -> int:
|
||||
return -1
|
||||
|
||||
|
||||
def _walk_operations(op, k):
|
||||
k -= 1
|
||||
if k < 0:
|
||||
return k
|
||||
for region in op.regions:
|
||||
for block in region:
|
||||
for child_op in block:
|
||||
k = _walk_operations(child_op, k)
|
||||
if k < 0:
|
||||
return k
|
||||
return k
|
||||
|
||||
|
||||
def use_detailed_logging(module: ir.Module) -> bool:
|
||||
"""Returns 'true' if detailed logging should be enabled for 'module'."""
|
||||
bound = _COMPILER_DETAILED_LOGGING_MIN_OPS.value
|
||||
return _walk_operations(module.operation, bound) < 0
|
||||
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
@ -81,6 +110,7 @@ def get_compile_options(
|
||||
auto_spmd_partitioning_mesh_ids: list[int] | None = None,
|
||||
env_options_overrides: dict[str, str] | None = None,
|
||||
fdo_profile: bytes | None = None,
|
||||
detailed_logging: bool = True,
|
||||
) -> xc.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
@ -101,7 +131,9 @@ def get_compile_options(
|
||||
auto_spmd_partitioning search space.
|
||||
env_options_overrides: dict of additional options parsed by the compiler
|
||||
fdo_profile: Optional profile for feedback-directed optimization passed to
|
||||
XLA.
|
||||
XLA.
|
||||
detailed_logging: Is this an "interesting" computation about which XLA
|
||||
would be wise to log compilation information?
|
||||
"""
|
||||
compile_options = xc.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
@ -178,6 +210,7 @@ def get_compile_options(
|
||||
logger.error("get_compile_options XLA-AutoFDO profile: " +
|
||||
"XLA-AutoFDO profile version is 0; this should not happen")
|
||||
|
||||
debug_options.xla_detailed_logging_and_dumping = detailed_logging
|
||||
return compile_options
|
||||
|
||||
|
||||
|
@ -911,6 +911,7 @@ class UnloadedPmapExecutable:
|
||||
device_assignment=device_assignment,
|
||||
use_spmd_partitioning=False,
|
||||
env_options_overrides=compiler_options,
|
||||
detailed_logging=compiler.use_detailed_logging(hlo)
|
||||
)
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
|
||||
@ -2502,6 +2503,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
use_auto_spmd_partitioning=auto_spmd_lowering,
|
||||
env_options_overrides=compiler_options,
|
||||
fdo_profile=fdo_profile,
|
||||
detailed_logging=compiler.use_detailed_logging(computation)
|
||||
)
|
||||
|
||||
opts = compile_options.executable_build_options
|
||||
|
Loading…
x
Reference in New Issue
Block a user