Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046

PiperOrigin-RevId: 582789416
This commit is contained in:
Peter Hawkins 2023-11-15 13:35:14 -08:00 committed by jax authors
parent 0560cc478e
commit 234be736c4
2 changed files with 28 additions and 1 deletions

View File

@ -18,9 +18,24 @@ Remember to align the itemized text with the first line of an item within a list
## jaxlib 0.4.21
* Changes
* In preparation for adding distributed CPU support, JAX now treats CPU
devices identically to GPU and TPU devices, that is:
* `jax.devices()` includes all devices present in a distributed job, even
those not local to the current process. `jax.local_devices()` still only
includes devices local to the current process, so if the change to
`jax.devices()` breaks you, you most likely want to use
`jax.local_devices()` instead.
* CPU devices now receive a globally unique ID number within a distributed
job; previously CPU devices would receive a process-local ID number.
* The `process_index` of each CPU device will now match any GPU or TPU
devices within the same process; previously the `process_index` of a CPU
device was always 0.
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
## jax 0.4.20 (Nov 2, 2023)
## jaxlib 0.4.20 (Nov 2, 2023)

View File

@ -204,8 +204,20 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
factory, priority, fail_quietly, experimental)
def make_cpu_client() -> xla_client.Client:
if xla_extension_version >= 216:
# TODO(phawkins): remove type: ignore after updating jaxlib version used for
# mypy checks.
return xla_client.make_cpu_client( # type: ignore
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
)
return xla_client.make_cpu_client()
register_backend_factory(
"cpu", xla_client.make_cpu_client, priority=0, fail_quietly=False
"cpu", make_cpu_client, priority=0, fail_quietly=False
)