mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Change JAX_PLATFORMS to raise an exception when platform initialization fails
This commit is contained in:
parent
989a3304bf
commit
df8c6263de
@ -21,6 +21,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Breaking changes
|
||||
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
|
||||
`max_cache_size_ bytes` anymore and will not get that as an input.
|
||||
* `JAX_PLATFORMS` now raises an exception when platform initialization fails.
|
||||
* Changes
|
||||
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
|
||||
that allows selection between an LU-decomposition based implementation and
|
||||
|
@ -602,6 +602,22 @@ jax2tf_associative_scan_reductions = config.define_bool_state(
|
||||
)
|
||||
)
|
||||
|
||||
jax_platforms = config.define_string_state(
|
||||
name='jax_platforms',
|
||||
default=None,
|
||||
help=(
|
||||
'Comma-separated list of platform names specifying which platforms jax '
|
||||
'should initialize. If any of the platforms in this list are not successfully '
|
||||
'initialized, an exception will be raised and the program will be aborted. '
|
||||
'The first platform in the list will be the default platform. '
|
||||
'For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends '
|
||||
'will be initialized, and the CPU backend will be used unless otherwise '
|
||||
'specified. If TPU initialization fails, it will raise an exception. '
|
||||
'By default, jax will try to initialize all available '
|
||||
'platforms and will default to GPU or TPU if available, and fallback to CPU '
|
||||
'otherwise.'
|
||||
))
|
||||
|
||||
enable_checks = config.define_bool_state(
|
||||
name='jax_enable_checks',
|
||||
default=False,
|
||||
|
@ -35,6 +35,7 @@ from jax._src.config import flags, bool_env, int_env
|
||||
from jax._src.lib import tpu_driver_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
from jax.config import config
|
||||
import numpy as np
|
||||
|
||||
iree: Optional[Any]
|
||||
@ -64,17 +65,6 @@ flags.DEFINE_string(
|
||||
'jax_platform_name',
|
||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||
'Deprecated, please use --jax_platforms instead.')
|
||||
flags.DEFINE_string(
|
||||
'jax_platforms',
|
||||
os.getenv('JAX_PLATFORMS', '').lower(),
|
||||
'Comma-separated list of platform names specifying which platforms jax '
|
||||
'should attempt to initialize. The first platform in the list that is '
|
||||
'successfully initialized will be used as the default platform. For '
|
||||
'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
|
||||
'initialized, and the CPU backend will be used unless otherwise specified; '
|
||||
'--jax_platforms=cpu means that only the CPU backend will be initialized. '
|
||||
'By default, jax will try to initialize all available platforms and will '
|
||||
'default to GPU or TPU if available, and fallback to CPU otherwise.')
|
||||
flags.DEFINE_bool(
|
||||
'jax_disable_most_optimizations',
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
@ -297,9 +287,8 @@ def backends():
|
||||
with _backend_lock:
|
||||
if _backends:
|
||||
return _backends
|
||||
|
||||
if FLAGS.jax_platforms:
|
||||
jax_platforms = FLAGS.jax_platforms.split(",")
|
||||
if config.jax_platforms:
|
||||
jax_platforms = config.jax_platforms.split(",")
|
||||
platforms = []
|
||||
# Allow platform aliases in the list of platforms.
|
||||
for platform in jax_platforms:
|
||||
@ -310,7 +299,6 @@ def backends():
|
||||
platforms_and_priorites = (
|
||||
(platform, priority) for platform, (_, priority)
|
||||
in _backend_factories.items())
|
||||
|
||||
default_priority = -1000
|
||||
for platform, priority in platforms_and_priorites:
|
||||
try:
|
||||
@ -337,10 +325,13 @@ def backends():
|
||||
else:
|
||||
# If the backend isn't built into the binary, or if it has no devices,
|
||||
# we expect a RuntimeError.
|
||||
logging.info("Unable to initialize backend '%s': %s", platform,
|
||||
err)
|
||||
_backends_errors[platform] = str(err)
|
||||
continue
|
||||
err_msg = f"Unable to initialize backend '{platform}': {err}"
|
||||
if config.jax_platforms:
|
||||
raise RuntimeError(err_msg)
|
||||
else:
|
||||
_backends_errors[platform] = str(err)
|
||||
logging.info(err_msg)
|
||||
continue
|
||||
# We don't warn about falling back to CPU on Mac OS, because we don't
|
||||
# support anything else there at the moment and warning would be pointless.
|
||||
if (py_platform.system() != "Darwin" and
|
||||
|
Loading…
x
Reference in New Issue
Block a user