diff --git a/jax/_src/config.py b/jax/_src/config.py index 1e46fb8bd..855740bf2 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -564,6 +564,7 @@ def int_state( update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, include_in_jit_key: bool = False, + validator: Callable[[Any], None] | None = None, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -596,6 +597,8 @@ def int_state( if new_val is not None and not isinstance(new_val, int): raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') + if new_val is not None and validator is not None: + validator(new_val) s = State[int](name, default, help, update_global_hook, update_thread_local_hook, validate, @@ -1804,15 +1807,6 @@ cpu_collectives_implementation = optional_enum_state( '("gloo", "mpi")'), ) -num_cpu_devices = int_state( - name="jax_num_cpu_devices", - default=-1, - help=( - "Number of CPU devices to use. If not provided, the value of " - "the XLA flag --xla_force_host_platform_device_count is used." - " Must be set before JAX is initialized."), -) - enable_empty_arrays = bool_state( name='jax_enable_empty_arrays', default=False, diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 66ee882dc..7db88f447 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -550,9 +550,14 @@ def request_cpu_devices(nr_devices: int): invoked. Test cases that require a specific number of devices should skip themselves if that number is not met. """ - if config.num_cpu_devices.value < nr_devices: + if xla_bridge.num_cpu_devices.value < nr_devices: xla_bridge.get_backend.cache_clear() - config.update("jax_num_cpu_devices", nr_devices) + # Don't raise an error for `request_cpu_devices` because we initialize the + # backend in OSS during collecting tests in pytest via `device_under_test`. + try: + config.update("jax_num_cpu_devices", nr_devices) + except RuntimeError: + pass def skip_on_flag(flag_name, skip_value): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 33b46145d..be96deab8 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -263,7 +263,7 @@ def make_cpu_client( # Already validated by config module assert collectives_impl is None - num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None + num_devices = num_cpu_devices.value if num_cpu_devices.value >= 0 else None return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client, @@ -1273,3 +1273,20 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): return xla_client.make_tfrt_tpu_c_api_device_topology( topology_name, **kwargs ) + +def _validate_backend_not_initialized(new_val): + if backends_are_initialized(): + raise RuntimeError( + "jax_num_cpu_devices config should be updated before backends are" + " initialized i.e. before any JAX operation is executed. You should" + " initialize this config immediately after `import jax`.") + +num_cpu_devices = config.int_state( + name="jax_num_cpu_devices", + default=-1, + help=( + "Number of CPU devices to use. If not provided, the value of " + "the XLA flag --xla_force_host_platform_device_count is used." + " Must be set before JAX is initialized."), + validator=_validate_backend_not_initialized, +) diff --git a/tests/api_test.py b/tests/api_test.py index c9cf28e0a..aece7b19f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4438,6 +4438,14 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): jax.device_put(jnp.arange(8), 'cpu') + def test_num_cpu_devices_called_after_initialization(self): + jax.devices() + with self.assertRaisesRegex( + RuntimeError, + "jax_num_cpu_devices config should be updated before backends are " + "initialized"): + config.update('jax_num_cpu_devices', 2) + @jtu.thread_unsafe_test() # logging is not thread-safe def test_clear_cache(self): @jax.jit