Add jax_xla_profile_version configuration

A new config named jax_xla_profile_version is added
to support XLA compilation profile.

PiperOrigin-RevId: 449276852
This commit is contained in:
Hyojun Kim 2022-05-17 11:44:14 -07:00 committed by jax authors
parent 1bcb5e073c
commit bc5a5e17a5

View File

@ -30,7 +30,7 @@ from absl import logging
logging._warn_preinit_stderr = 0
import jax._src.lib
from jax._src.config import flags, bool_env
from jax._src.config import flags, bool_env, int_env
from jax._src.lib import tpu_driver_client
from jax._src.lib import xla_client
from jax._src import util, traceback_util
@ -79,6 +79,11 @@ flags.DEFINE_bool(
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
flags.DEFINE_integer(
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
'Optional profile version for XLA compilation. '
'This is meaningful only when XLA is configured to '
'support the remote compilation profile feature.')
def get_compile_options(
num_replicas: int,
@ -152,6 +157,8 @@ def get_compile_options(
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
if jax._src.lib.xla_extension_version >= 68:
compile_options.profile_version = FLAGS.jax_xla_profile_version
return compile_options