# Patch summary (preamble text; everything before the first "diff --git"
# header is not part of any hunk).
#
# CHANGELOG.md:
#   Adds an entry under "jaxlib 0.4.21 / Changes" announcing that CPU devices
#   are now treated like GPU and TPU devices in a distributed job
#   (`jax.devices()` contents, globally unique device IDs, `process_index`).
#
# jax/_src/xla_bridge.py:
#   Adds a module-level `make_cpu_client()` wrapper and registers it as the
#   "cpu" backend factory in place of `xla_client.make_cpu_client`.
#   * When `xla_extension_version >= 216`, the wrapper calls
#     `xla_client.make_cpu_client` with `distributed_client`, `node_id` and
#     `num_nodes` taken from `distributed.global_state` (`client`,
#     `process_id`, `num_processes`).
#   * Otherwise it calls `xla_client.make_cpu_client()` with no arguments.
#   * The `# type: ignore` on the keyword-argument call is temporary; see the
#     TODO(phawkins) inside the hunk.
#
# NOTE(review): this copy of the patch was restored from a whitespace-collapsed
# form. Blank context lines were placed so that each hunk matches its header
# counts (-18,9 +18,24 and -204,8 +204,20). Leading indentation inside the
# hunks follows the usual Markdown list / 2-space Python layout -- confirm
# against the target files before applying.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index becdd9b0c..271a8f597 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,9 +18,24 @@ Remember to align the itemized text with the first line of an item within a list
 ## jaxlib 0.4.21
 
 * Changes
+  * In preparation for adding distributed CPU support, JAX now treats CPU
+    devices identically to GPU and TPU devices, that is:
+
+    * `jax.devices()` includes all devices present in a distributed job, even
+      those not local to the current process. `jax.local_devices()` still only
+      includes devices local to the current process, so if the change to
+      `jax.devices()` breaks you, you most likely want to use
+      `jax.local_devices()` instead.
+    * CPU devices now receive a globally unique ID number within a distributed
+      job; previously CPU devices would receive a process-local ID number.
+    * The `process_index` of each CPU device will now match any GPU or TPU
+      devices within the same process; previously the `process_index` of a CPU
+      device was always 0.
+
   * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
     1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
 
+
 ## jax 0.4.20 (Nov 2, 2023)
 
 ## jaxlib 0.4.20 (Nov 2, 2023)
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 0253cd067..3b147d41e 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -204,8 +204,20 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
       factory, priority, fail_quietly, experimental)
 
 
+def make_cpu_client() -> xla_client.Client:
+  if xla_extension_version >= 216:
+    # TODO(phawkins): remove type: ignore after updating jaxlib version used for
+    # mypy checks.
+    return xla_client.make_cpu_client(  # type: ignore
+        distributed_client=distributed.global_state.client,
+        node_id=distributed.global_state.process_id,
+        num_nodes=distributed.global_state.num_processes,
+    )
+  return xla_client.make_cpu_client()
+
+
 register_backend_factory(
-    "cpu", xla_client.make_cpu_client, priority=0, fail_quietly=False
+    "cpu", make_cpu_client, priority=0, fail_quietly=False
 )
 
 