mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:16:05 +00:00
Add jax_xla_profile_version configuration
A new config named jax_xla_profile_version is added to support XLA compilation profile. PiperOrigin-RevId: 449276852
This commit is contained in:
parent
1bcb5e073c
commit
bc5a5e17a5
@ -30,7 +30,7 @@ from absl import logging
|
||||
logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax._src.lib
|
||||
from jax._src.config import flags, bool_env
|
||||
from jax._src.config import flags, bool_env, int_env
|
||||
from jax._src.lib import tpu_driver_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
@ -79,6 +79,11 @@ flags.DEFINE_bool(
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
flags.DEFINE_integer(
|
||||
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
|
||||
'Optional profile version for XLA compilation. '
|
||||
'This is meaningful only when XLA is configured to '
|
||||
'support the remote compilation profile feature.')
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
@ -152,6 +157,8 @@ def get_compile_options(
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
if jax._src.lib.xla_extension_version >= 68:
|
||||
compile_options.profile_version = FLAGS.jax_xla_profile_version
|
||||
return compile_options
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user