mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Raise an error if jax.config.update('jax_num_cpu_devices', val)
is called after the backend is initialized
PiperOrigin-RevId: 736646012
This commit is contained in:
parent
47bf22e37d
commit
e1b62cede1
@ -564,6 +564,7 @@ def int_state(
|
||||
update_global_hook: Callable[[int], None] | None = None,
|
||||
update_thread_local_hook: Callable[[int | None], None] | None = None,
|
||||
include_in_jit_key: bool = False,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
) -> State[int]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
@ -596,6 +597,8 @@ def int_state(
|
||||
if new_val is not None and not isinstance(new_val, int):
|
||||
raise ValueError(f'new int config value must be None or of type int, '
|
||||
f'got {new_val} of type {type(new_val)}')
|
||||
if new_val is not None and validator is not None:
|
||||
validator(new_val)
|
||||
|
||||
s = State[int](name, default, help, update_global_hook,
|
||||
update_thread_local_hook, validate,
|
||||
@ -1804,15 +1807,6 @@ cpu_collectives_implementation = optional_enum_state(
|
||||
'("gloo", "mpi")'),
|
||||
)
|
||||
|
||||
num_cpu_devices = int_state(
|
||||
name="jax_num_cpu_devices",
|
||||
default=-1,
|
||||
help=(
|
||||
"Number of CPU devices to use. If not provided, the value of "
|
||||
"the XLA flag --xla_force_host_platform_device_count is used."
|
||||
" Must be set before JAX is initialized."),
|
||||
)
|
||||
|
||||
enable_empty_arrays = bool_state(
|
||||
name='jax_enable_empty_arrays',
|
||||
default=False,
|
||||
|
@ -550,9 +550,14 @@ def request_cpu_devices(nr_devices: int):
|
||||
invoked. Test cases that require a specific number of devices should skip
|
||||
themselves if that number is not met.
|
||||
"""
|
||||
if config.num_cpu_devices.value < nr_devices:
|
||||
if xla_bridge.num_cpu_devices.value < nr_devices:
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
config.update("jax_num_cpu_devices", nr_devices)
|
||||
# Don't raise an error for `request_cpu_devices` because we initialize the
|
||||
# backend in OSS during collecting tests in pytest via `device_under_test`.
|
||||
try:
|
||||
config.update("jax_num_cpu_devices", nr_devices)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def skip_on_flag(flag_name, skip_value):
|
||||
|
@ -263,7 +263,7 @@ def make_cpu_client(
|
||||
# Already validated by config module
|
||||
assert collectives_impl is None
|
||||
|
||||
num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None
|
||||
num_devices = num_cpu_devices.value if num_cpu_devices.value >= 0 else None
|
||||
return xla_client.make_cpu_client(
|
||||
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
||||
distributed_client=distributed.global_state.client,
|
||||
@ -1273,3 +1273,20 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs):
|
||||
return xla_client.make_tfrt_tpu_c_api_device_topology(
|
||||
topology_name, **kwargs
|
||||
)
|
||||
|
||||
def _validate_backend_not_initialized(new_val):
  """Validator for the `jax_num_cpu_devices` config state.

  Args:
    new_val: the proposed config value. It is not inspected; only the backend
      initialization state decides the outcome.

  Raises:
    RuntimeError: if `backends_are_initialized()` returns a truthy value.
  """
  # Nothing to reject while no backend has been created yet.
  if not backends_are_initialized():
    return
  raise RuntimeError(
      "jax_num_cpu_devices config should be updated before backends are"
      " initialized i.e. before any JAX operation is executed. You should"
      " initialize this config immediately after `import jax`.")
|
||||
|
||||
# Config state holding the requested number of CPU devices. The default of -1
# stands for "not provided". Updates are passed through
# `_validate_backend_not_initialized`, which raises RuntimeError once backends
# have been initialized.
num_cpu_devices = config.int_state(
    name="jax_num_cpu_devices",
    default=-1,
    help=(
        "Number of CPU devices to use. If not provided, the value of "
        "the XLA flag --xla_force_host_platform_device_count is used."
        " Must be set before JAX is initialized."),
    validator=_validate_backend_not_initialized,
)
|
||||
|
@ -4438,6 +4438,14 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"):
|
||||
jax.device_put(jnp.arange(8), 'cpu')
|
||||
|
||||
def test_num_cpu_devices_called_after_initialization(self):
  """Updating `jax_num_cpu_devices` after `jax.devices()` must raise."""
  # Query devices first so the config update below happens "after
  # initialization", as the test name requires.
  jax.devices()
  expected_message = (
      "jax_num_cpu_devices config should be updated before backends are "
      "initialized")
  with self.assertRaisesRegex(RuntimeError, expected_message):
    config.update('jax_num_cpu_devices', 2)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_clear_cache(self):
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user