Change JAX_PLATFORMS to raise an exception when platform initialization fails

This commit is contained in:
Shiva Shahrokhi 2022-06-17 16:38:56 +00:00
parent 989a3304bf
commit df8c6263de
3 changed files with 27 additions and 19 deletions

View File

@ -21,6 +21,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
`max_cache_size_ bytes` anymore and will not get that as an input.
* `JAX_PLATFORMS` now raises an exception when platform initialization fails.
* Changes
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and

View File

@ -602,6 +602,22 @@ jax2tf_associative_scan_reductions = config.define_bool_state(
)
)
jax_platforms = config.define_string_state(
name='jax_platforms',
default=None,
help=(
'Comma-separated list of platform names specifying which platforms jax '
'should initialize. If any of the platforms in this list are not successfully '
'initialized, an exception will be raised and the program will be aborted. '
'The first platform in the list will be the default platform. '
'For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends '
'will be initialized, and the CPU backend will be used unless otherwise '
'specified. If TPU initialization fails, it will raise an exception. '
'By default, jax will try to initialize all available '
'platforms and will default to GPU or TPU if available, and fallback to CPU '
'otherwise.'
))
enable_checks = config.define_bool_state(
name='jax_enable_checks',
default=False,

View File

@ -35,6 +35,7 @@ from jax._src.config import flags, bool_env, int_env
from jax._src.lib import tpu_driver_client
from jax._src.lib import xla_client
from jax._src import util, traceback_util
from jax.config import config
import numpy as np
iree: Optional[Any]
@ -64,17 +65,6 @@ flags.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_string(
'jax_platforms',
os.getenv('JAX_PLATFORMS', '').lower(),
'Comma-separated list of platform names specifying which platforms jax '
'should attempt to initialize. The first platform in the list that is '
'successfully initialized will be used as the default platform. For '
'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
'initialized, and the CPU backend will be used unless otherwise specified; '
'--jax_platforms=cpu means that only the CPU backend will be initialized. '
'By default, jax will try to initialize all available platforms and will '
'default to GPU or TPU if available, and fallback to CPU otherwise.')
flags.DEFINE_bool(
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
@ -297,9 +287,8 @@ def backends():
with _backend_lock:
if _backends:
return _backends
if FLAGS.jax_platforms:
jax_platforms = FLAGS.jax_platforms.split(",")
if config.jax_platforms:
jax_platforms = config.jax_platforms.split(",")
platforms = []
# Allow platform aliases in the list of platforms.
for platform in jax_platforms:
@ -310,7 +299,6 @@ def backends():
platforms_and_priorites = (
(platform, priority) for platform, (_, priority)
in _backend_factories.items())
default_priority = -1000
for platform, priority in platforms_and_priorites:
try:
@ -337,10 +325,13 @@ def backends():
else:
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
logging.info("Unable to initialize backend '%s': %s", platform,
err)
_backends_errors[platform] = str(err)
continue
err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.jax_platforms:
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
logging.info(err_msg)
continue
# We don't warn about falling back to CPU on Mac OS, because we don't
# support anything else there at the moment and warning would be pointless.
if (py_platform.system() != "Darwin" and