mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
XLA-AutoFDO is supported only for TPUs, so requesting the latest profile version for non-TPU workloads is unnecessary and can delay the completion of initialization. Testing: test workload. PiperOrigin-RevId: 584148686
This commit is contained in:
parent
03cae62c78
commit
fc8058a17d
@ -61,6 +61,9 @@ _COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
|
||||
),
|
||||
)
|
||||
|
||||
# The special XLA-AutoFDO profile version that indicates that a profile is not
|
||||
# available and retrieval should not be attempted.
|
||||
_NO_PROFILE_DONT_RETRIEVE = -1
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
@ -76,7 +79,9 @@ _cache_used: bool = False
|
||||
|
||||
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
|
||||
# version. The default (-1) takes care of errors.
|
||||
def get_latest_profile_version() -> int:
|
||||
# TODO(b/289098047): consider refactoring this interface.
|
||||
def get_latest_profile_version(backend: xc.Client) -> int:
|
||||
del backend
|
||||
return -1
|
||||
|
||||
|
||||
@ -110,6 +115,7 @@ def get_compile_options(
|
||||
env_options_overrides: dict[str, str] | None = None,
|
||||
fdo_profile: bytes | None = None,
|
||||
detailed_logging: bool = True,
|
||||
backend: xc.Client | None = None,
|
||||
) -> xc.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
@ -133,6 +139,7 @@ def get_compile_options(
|
||||
XLA.
|
||||
detailed_logging: Is this an "interesting" computation about which XLA
|
||||
would be wise to log compilation information?
|
||||
backend: the client, if available.
|
||||
"""
|
||||
compile_options = xc.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
@ -197,17 +204,20 @@ def get_compile_options(
|
||||
"using JAX XLA profile version %d from flag",
|
||||
jax_xla_profile_version)
|
||||
else:
|
||||
fdo_profile_version = get_latest_profile_version()
|
||||
if fdo_profile_version != 0:
|
||||
compile_options.profile_version = fdo_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using XLA-AutoFDO profile version %d",
|
||||
fdo_profile_version)
|
||||
compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE
|
||||
if backend is None:
|
||||
logging.info("get_compile_options: no backend supplied; "
|
||||
"disabling XLA-AutoFDO profile")
|
||||
else:
|
||||
no_profile_dont_retrieve = -1
|
||||
compile_options.profile_version = no_profile_dont_retrieve
|
||||
logger.error("get_compile_options XLA-AutoFDO profile: " +
|
||||
"XLA-AutoFDO profile version is 0; this should not happen")
|
||||
fdo_profile_version = get_latest_profile_version(backend)
|
||||
if fdo_profile_version != 0:
|
||||
compile_options.profile_version = fdo_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using XLA-AutoFDO profile version %d",
|
||||
fdo_profile_version)
|
||||
else:
|
||||
logger.error("get_compile_options XLA-AutoFDO profile: " +
|
||||
"XLA-AutoFDO profile version is 0; this should not happen")
|
||||
|
||||
debug_options.xla_detailed_logging = detailed_logging
|
||||
|
||||
|
@ -915,7 +915,8 @@ class UnloadedPmapExecutable:
|
||||
device_assignment=device_assignment,
|
||||
use_spmd_partitioning=False,
|
||||
env_options_overrides=compiler_options,
|
||||
detailed_logging=compiler.use_detailed_logging(hlo)
|
||||
detailed_logging=compiler.use_detailed_logging(hlo),
|
||||
backend=pci.backend,
|
||||
)
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
|
||||
@ -2505,7 +2506,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
use_auto_spmd_partitioning=auto_spmd_lowering,
|
||||
env_options_overrides=compiler_options,
|
||||
fdo_profile=fdo_profile,
|
||||
detailed_logging=compiler.use_detailed_logging(computation)
|
||||
detailed_logging=compiler.use_detailed_logging(computation),
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
opts = compile_options.executable_build_options
|
||||
|
@ -59,15 +59,19 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_autofdo_profile(self):
|
||||
|
||||
class _DummyBackend:
|
||||
platform: str = "tpu"
|
||||
|
||||
# --jax_xla_profile_version takes precedence.
|
||||
jax_flag_profile = 1
|
||||
another_profile = 2
|
||||
with config.jax_xla_profile_version(jax_flag_profile):
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: another_profile):
|
||||
side_effect=lambda _: another_profile):
|
||||
self.assertEqual(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
|
||||
).profile_version,
|
||||
jax_flag_profile,
|
||||
)
|
||||
@ -76,10 +80,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
# returns if --jax_xla_profile_version is not set.
|
||||
profile_version = 1
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: profile_version):
|
||||
side_effect=lambda _: profile_version):
|
||||
self.assertEqual(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
|
||||
).profile_version,
|
||||
profile_version,
|
||||
)
|
||||
@ -90,10 +94,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
error_return = 0
|
||||
no_profile_dont_retrieve = -1
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: error_return):
|
||||
side_effect=lambda _: error_return):
|
||||
self.assertEqual(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
|
||||
).profile_version,
|
||||
no_profile_dont_retrieve,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user