Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964

PiperOrigin-RevId: 582275667
This commit is contained in:
Peter Hawkins 2023-11-14 04:31:15 -08:00 committed by jax authors
parent 7cf66dfe4b
commit ef9075159a
2 changed files with 1 additions and 21 deletions

View File

@ -18,14 +18,6 @@ Remember to align the itemized text with the first line of an item within a list
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
* Bug fixes
* Distributed JAX programs will now receive a globally unique ID number for
CPU devices. Previously, in a distributed GPU or TPU job, the CPU devices
attached to each process did not receive unique global ID numbers.
In addition the `process_index` attached to each CPU device will now match
any GPU or TPU devices.
## jax 0.4.20 (Nov 2, 2023)
## jaxlib 0.4.20 (Nov 2, 2023)

View File

@ -204,20 +204,8 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
factory, priority, fail_quietly, experimental)
def make_cpu_client() -> xla_client.Client:
if xla_extension_version >= 214:
# TODO(phawkins): remove type: ignore after updating jaxlib version used for
# mypy checks.
return xla_client.make_cpu_client( # type: ignore
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
)
return xla_client.make_cpu_client()
register_backend_factory(
"cpu", make_cpu_client, priority=0, fail_quietly=False
"cpu", xla_client.make_cpu_client, priority=0, fail_quietly=False
)