mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964
PiperOrigin-RevId: 582275667
This commit is contained in:
parent
7cf66dfe4b
commit
ef9075159a
@ -18,14 +18,6 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
|
||||
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
|
||||
|
||||
* Bug fixes
|
||||
* Distributed JAX programs will now receive a globally unique ID number for
|
||||
CPU devices. Previously, in a distributed GPU or TPU job, the CPU devices
|
||||
attached to each process did not receive unique global ID numbers.
|
||||
In addition the `process_index` attached to each CPU device will now match
|
||||
any GPU or TPU devices.
|
||||
|
||||
|
||||
## jax 0.4.20 (Nov 2, 2023)
|
||||
|
||||
## jaxlib 0.4.20 (Nov 2, 2023)
|
||||
|
@ -204,20 +204,8 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
|
||||
factory, priority, fail_quietly, experimental)
|
||||
|
||||
|
||||
def make_cpu_client() -> xla_client.Client:
|
||||
if xla_extension_version >= 214:
|
||||
# TODO(phawkins): remove type: ignore after updating jaxlib version used for
|
||||
# mypy checks.
|
||||
return xla_client.make_cpu_client( # type: ignore
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
)
|
||||
return xla_client.make_cpu_client()
|
||||
|
||||
|
||||
register_backend_factory(
|
||||
"cpu", make_cpu_client, priority=0, fail_quietly=False
|
||||
"cpu", xla_client.make_cpu_client, priority=0, fail_quietly=False
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user