tweak xla_bridge.py flags

* add environment variables for jax_disable_most_optimizations and
  jax_cpu_backend_variant
* comment on the default values in help strings
This commit is contained in:
Matthew Johnson 2021-07-02 10:31:28 -07:00
parent c97d63dec3
commit 5e92faccbb

View File

@ -30,7 +30,7 @@ from absl import logging
logging._warn_preinit_stderr = 0
import jax.lib
from .._src.config import flags
from .._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
@ -52,21 +52,26 @@ flags.DEFINE_string(
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
flags.DEFINE_string(
'jax_backend_target', 'local',
'Either "local" or "rpc:address" to connect to a remote service target.')
'Either "local" or "rpc:address" to connect to a remote service target. '
'The default is "local".')
flags.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', ''),
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
'available, but fall back to CPU otherwise. To set the platform manually, '
'pass "cpu" for CPU, "gpu" for GPU, etc.')
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
'setting the platform name to "cpu" can silence warnings that appear with '
'the default setting.')
flags.DEFINE_bool(
'jax_disable_most_optimizations', False,
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
flags.DEFINE_string(
'jax_cpu_backend_variant', 'tfrt',
'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
'tfrt')
'jax_cpu_backend_variant',
os.getenv('JAX_CPU_BACKEND_VARIANT', 'tfrt'),
'Selects CPU backend runtime variant: "stream_executor" or "tfrt". The '
'default is "tfrt".')
def get_compile_options(
num_replicas: int,