mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
tweak xla_bridge.py flags
* add environment variables for jax_disable_most_optimizations and jax_cpu_backend_variant * comment on the default values in help strings
This commit is contained in:
parent
c97d63dec3
commit
5e92faccbb
@ -30,7 +30,7 @@ from absl import logging
|
||||
logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax.lib
|
||||
from .._src.config import flags
|
||||
from .._src.config import flags, bool_env
|
||||
from . import tpu_driver_client
|
||||
from . import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
@ -52,21 +52,26 @@ flags.DEFINE_string(
|
||||
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
|
||||
flags.DEFINE_string(
|
||||
'jax_backend_target', 'local',
|
||||
'Either "local" or "rpc:address" to connect to a remote service target.')
|
||||
'Either "local" or "rpc:address" to connect to a remote service target. '
|
||||
'The default is "local".')
|
||||
flags.DEFINE_string(
|
||||
'jax_platform_name',
|
||||
os.getenv('JAX_PLATFORM_NAME', ''),
|
||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
|
||||
'available, but fall back to CPU otherwise. To set the platform manually, '
|
||||
'pass "cpu" for CPU, "gpu" for GPU, etc.')
|
||||
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
|
||||
'setting the platform name to "cpu" can silence warnings that appear with '
|
||||
'the default setting.')
|
||||
flags.DEFINE_bool(
|
||||
'jax_disable_most_optimizations', False,
|
||||
'jax_disable_most_optimizations',
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
flags.DEFINE_string(
|
||||
'jax_cpu_backend_variant', 'tfrt',
|
||||
'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
|
||||
'tfrt')
|
||||
'jax_cpu_backend_variant',
|
||||
os.getenv('JAX_CPU_BACKEND_VARIANT', 'tfrt'),
|
||||
'Selects CPU backend runtime variant: "stream_executor" or "tfrt". The '
|
||||
'default is "tfrt".')
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
|
Loading…
x
Reference in New Issue
Block a user