Restrict retrieving XLA-AutoFDO profile version to TPU workloads.

XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.

Testing: test workload.
PiperOrigin-RevId: 584148686
This commit is contained in:
jax authors 2023-11-20 15:51:27 -08:00
parent 03cae62c78
commit fc8058a17d
3 changed files with 35 additions and 19 deletions

View File

@ -61,6 +61,9 @@ _COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
),
)
# The special XLA-AutoFDO profile version that indicates that a profile is not
# available and retrieval should not be attempted.
_NO_PROFILE_DONT_RETRIEVE = -1
traceback_util.register_exclusion(__file__)
@ -76,7 +79,9 @@ _cache_used: bool = False
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default (-1) takes care of errors.
def get_latest_profile_version() -> int:
# TODO(b/289098047): consider refactoring this interface.
def get_latest_profile_version(backend: xc.Client) -> int:
del backend
return -1
@ -110,6 +115,7 @@ def get_compile_options(
env_options_overrides: dict[str, str] | None = None,
fdo_profile: bytes | None = None,
detailed_logging: bool = True,
backend: xc.Client | None = None,
) -> xc.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
@ -133,6 +139,7 @@ def get_compile_options(
XLA.
detailed_logging: Is this an "interesting" computation about which XLA
would be wise to log compilation information?
backend: the client, if available.
"""
compile_options = xc.CompileOptions()
compile_options.num_replicas = num_replicas
@ -197,17 +204,20 @@ def get_compile_options(
"using JAX XLA profile version %d from flag",
jax_xla_profile_version)
else:
fdo_profile_version = get_latest_profile_version()
if fdo_profile_version != 0:
compile_options.profile_version = fdo_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using XLA-AutoFDO profile version %d",
fdo_profile_version)
compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE
if backend is None:
logging.info("get_compile_options: no backend supplied; "
"disabling XLA-AutoFDO profile")
else:
no_profile_dont_retrieve = -1
compile_options.profile_version = no_profile_dont_retrieve
logger.error("get_compile_options XLA-AutoFDO profile: " +
"XLA-AutoFDO profile version is 0; this should not happen")
fdo_profile_version = get_latest_profile_version(backend)
if fdo_profile_version != 0:
compile_options.profile_version = fdo_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using XLA-AutoFDO profile version %d",
fdo_profile_version)
else:
logger.error("get_compile_options XLA-AutoFDO profile: " +
"XLA-AutoFDO profile version is 0; this should not happen")
debug_options.xla_detailed_logging = detailed_logging

View File

@ -915,7 +915,8 @@ class UnloadedPmapExecutable:
device_assignment=device_assignment,
use_spmd_partitioning=False,
env_options_overrides=compiler_options,
detailed_logging=compiler.use_detailed_logging(hlo)
detailed_logging=compiler.use_detailed_logging(hlo),
backend=pci.backend,
)
compile_options.parameter_is_tupled_arguments = tuple_args
@ -2505,7 +2506,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
fdo_profile=fdo_profile,
detailed_logging=compiler.use_detailed_logging(computation)
detailed_logging=compiler.use_detailed_logging(computation),
backend=backend,
)
opts = compile_options.executable_build_options

View File

@ -59,15 +59,19 @@ class XlaBridgeTest(jtu.JaxTestCase):
)
def test_autofdo_profile(self):
class _DummyBackend:
platform: str = "tpu"
# --jax_xla_profile_version takes precedence.
jax_flag_profile = 1
another_profile = 2
with config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: another_profile):
side_effect=lambda _: another_profile):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
jax_flag_profile,
)
@ -76,10 +80,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
# returns if --jax_xla_profile_version is not set.
profile_version = 1
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: profile_version):
side_effect=lambda _: profile_version):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
profile_version,
)
@ -90,10 +94,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
error_return = 0
no_profile_dont_retrieve = -1
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: error_return):
side_effect=lambda _: error_return):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
no_profile_dont_retrieve,
)