Change --jax_xla_profile_version definition to config.

Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.

Testing: test workload.
PiperOrigin-RevId: 554942573
This commit is contained in:
jax authors 2023-08-08 14:28:35 -07:00
parent 3e50fea29e
commit d01695c746
2 changed files with 10 additions and 7 deletions

View File

@ -1081,6 +1081,14 @@ config.define_bool_state(
'work under pmap/pjit.')
)
jax_xla_profile_version = config.define_int_state(
name='jax_xla_profile_version',
default=0,
help=('Optional profile version for XLA compilation. This is meaningful '
'only when XLA is configured to support the remote compilation '
'profile feature.')
)
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""

View File

@ -38,7 +38,7 @@ import numpy as np
from jax._src import lib
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import bool_env, config, int_env
from jax._src.config import bool_env, config
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src import util
@ -78,11 +78,6 @@ _DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
'Optional profile version for XLA compilation. '
'This is meaningful only when XLA is configured to '
'support the remote compilation profile feature.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
@ -175,7 +170,7 @@ def get_compile_options(
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = _XLA_PROFILE_VERSION.value
compile_options.profile_version = config.jax_xla_profile_version
return compile_options