From fc8058a17d894705ec32969cb99526f9f9984296 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 20 Nov 2023 15:51:27 -0800 Subject: [PATCH] Restrict retrieving XLA-AutoFDO profile version to TPU workloads. XLA-AutoFDO is supported only for TPUs, so requesting the latest profile version for non-TPU workloads is unnecessary and can delay the completion of initialization. Testing: test workload. PiperOrigin-RevId: 584148686 --- jax/_src/compiler.py | 32 +++++++++++++++++++++----------- jax/_src/interpreters/pxla.py | 6 ++++-- tests/xla_bridge_test.py | 16 ++++++++++------ 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index fe51aec8c..ef7932108 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -61,6 +61,9 @@ _COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer( ), ) +# The special XLA-AutoFDO profile version that indicates that a profile is not +# available and retrieval should not be attempted. +_NO_PROFILE_DONT_RETRIEVE = -1 traceback_util.register_exclusion(__file__) @@ -76,7 +79,9 @@ _cache_used: bool = False # Will be monkeypatched with the function that gets the XLA-AutoFDO profile # version. The default (-1) takes care of errors. -def get_latest_profile_version() -> int: +# TODO(b/289098047): consider refactoring this interface. +def get_latest_profile_version(backend: xc.Client) -> int: + del backend return -1 @@ -110,6 +115,7 @@ def get_compile_options( env_options_overrides: dict[str, str] | None = None, fdo_profile: bytes | None = None, detailed_logging: bool = True, + backend: xc.Client | None = None, ) -> xc.CompileOptions: """Returns the compile options to use, as derived from flag values. @@ -133,6 +139,7 @@ def get_compile_options( XLA. detailed_logging: Is this an "interesting" computation about which XLA would be wise to log compilation information? + backend: the client, if available. """ compile_options = xc.CompileOptions() compile_options.num_replicas = num_replicas @@ -197,17 +204,20 @@ def get_compile_options( "using JAX XLA profile version %d from flag", jax_xla_profile_version) else: - fdo_profile_version = get_latest_profile_version() - if fdo_profile_version != 0: - compile_options.profile_version = fdo_profile_version - logger.debug("get_compile_options XLA-AutoFDO profile: " + - "using XLA-AutoFDO profile version %d", - fdo_profile_version) + compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE + if backend is None: + logging.info("get_compile_options: no backend supplied; " + "disabling XLA-AutoFDO profile") else: - no_profile_dont_retrieve = -1 - compile_options.profile_version = no_profile_dont_retrieve - logger.error("get_compile_options XLA-AutoFDO profile: " + - "XLA-AutoFDO profile version is 0; this should not happen") + fdo_profile_version = get_latest_profile_version(backend) + if fdo_profile_version != 0: + compile_options.profile_version = fdo_profile_version + logger.debug("get_compile_options XLA-AutoFDO profile: " + + "using XLA-AutoFDO profile version %d", + fdo_profile_version) + else: + logger.error("get_compile_options XLA-AutoFDO profile: " + + "XLA-AutoFDO profile version is 0; this should not happen") debug_options.xla_detailed_logging = detailed_logging diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 58f6abcc2..35efac070 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -915,7 +915,8 @@ class UnloadedPmapExecutable: device_assignment=device_assignment, use_spmd_partitioning=False, env_options_overrides=compiler_options, - detailed_logging=compiler.use_detailed_logging(hlo) + detailed_logging=compiler.use_detailed_logging(hlo), + backend=pci.backend, ) compile_options.parameter_is_tupled_arguments = tuple_args @@ -2505,7 +2506,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, use_auto_spmd_partitioning=auto_spmd_lowering, env_options_overrides=compiler_options, fdo_profile=fdo_profile, - detailed_logging=compiler.use_detailed_logging(computation) + detailed_logging=compiler.use_detailed_logging(computation), + backend=backend, ) opts = compile_options.executable_build_options diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index ac41c9a99..b5522e2e5 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -59,15 +59,19 @@ class XlaBridgeTest(jtu.JaxTestCase): ) def test_autofdo_profile(self): + + class _DummyBackend: + platform: str = "tpu" + # --jax_xla_profile_version takes precedence. jax_flag_profile = 1 another_profile = 2 with config.jax_xla_profile_version(jax_flag_profile): with mock.patch.object(compiler, "get_latest_profile_version", - side_effect=lambda: another_profile): + side_effect=lambda _: another_profile): self.assertEqual( compiler.get_compile_options( - num_replicas=3, num_partitions=4 + num_replicas=3, num_partitions=4, backend=_DummyBackend(), ).profile_version, jax_flag_profile, ) @@ -76,10 +80,10 @@ class XlaBridgeTest(jtu.JaxTestCase): # returns if --jax_xla_profile_version is not set. profile_version = 1 with mock.patch.object(compiler, "get_latest_profile_version", - side_effect=lambda: profile_version): + side_effect=lambda _: profile_version): self.assertEqual( compiler.get_compile_options( - num_replicas=3, num_partitions=4 + num_replicas=3, num_partitions=4, backend=_DummyBackend(), ).profile_version, profile_version, ) @@ -90,10 +94,10 @@ class XlaBridgeTest(jtu.JaxTestCase): error_return = 0 no_profile_dont_retrieve = -1 with mock.patch.object(compiler, "get_latest_profile_version", - side_effect=lambda: error_return): + side_effect=lambda _: error_return): self.assertEqual( compiler.get_compile_options( - num_replicas=3, num_partitions=4 + num_replicas=3, num_partitions=4, backend=_DummyBackend(), ).profile_version, no_profile_dont_retrieve, )