mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Change --jax_xla_profile_version definition to config.
Changing the flag to a config permits more contained testing. This is in preparation for an upcoming change to incorporate AutoFDO profile versions in the cache key. Testing: test workload. PiperOrigin-RevId: 554942573
This commit is contained in:
parent
3e50fea29e
commit
d01695c746
@ -1081,6 +1081,14 @@ config.define_bool_state(
|
||||
'work under pmap/pjit.')
|
||||
)
|
||||
|
||||
jax_xla_profile_version = config.define_int_state(
|
||||
name='jax_xla_profile_version',
|
||||
default=0,
|
||||
help=('Optional profile version for XLA compilation. This is meaningful '
|
||||
'only when XLA is configured to support the remote compilation '
|
||||
'profile feature.')
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
|
@ -38,7 +38,7 @@ import numpy as np
|
||||
from jax._src import lib
|
||||
from jax._src import distributed
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import bool_env, config, int_env
|
||||
from jax._src.config import bool_env, config
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -78,11 +78,6 @@ _DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
|
||||
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
|
||||
'Optional profile version for XLA compilation. '
|
||||
'This is meaningful only when XLA is configured to '
|
||||
'support the remote compilation profile feature.')
|
||||
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'jax_cuda_visible_devices', 'all',
|
||||
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
|
||||
@ -175,7 +170,7 @@ def get_compile_options(
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
compile_options.profile_version = _XLA_PROFILE_VERSION.value
|
||||
compile_options.profile_version = config.jax_xla_profile_version
|
||||
return compile_options
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user