mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046
PiperOrigin-RevId: 582789416
This commit is contained in:
parent
0560cc478e
commit
234be736c4
15
CHANGELOG.md
15
CHANGELOG.md
@ -18,9 +18,24 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
## jaxlib 0.4.21
|
||||
|
||||
* Changes
|
||||
* In preparation for adding distributed CPU support, JAX now treats CPU
|
||||
devices identically to GPU and TPU devices, that is:
|
||||
|
||||
* `jax.devices()` includes all devices present in a distributed job, even
|
||||
those not local to the current process. `jax.local_devices()` still only
|
||||
includes devices local to the current process, so if the change to
|
||||
`jax.devices()` breaks you, you most likely want to use
|
||||
`jax.local_devices()` instead.
|
||||
* CPU devices now receive a globally unique ID number within a distributed
|
||||
job; previously CPU devices would receive a process-local ID number.
|
||||
* The `process_index` of each CPU device will now match any GPU or TPU
|
||||
devices within the same process; previously the `process_index` of a CPU
|
||||
device was always 0.
|
||||
|
||||
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
|
||||
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
|
||||
|
||||
|
||||
## jax 0.4.20 (Nov 2, 2023)
|
||||
|
||||
## jaxlib 0.4.20 (Nov 2, 2023)
|
||||
|
@ -204,8 +204,20 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
|
||||
factory, priority, fail_quietly, experimental)
|
||||
|
||||
|
||||
def make_cpu_client() -> xla_client.Client:
|
||||
if xla_extension_version >= 216:
|
||||
# TODO(phawkins): remove type: ignore after updating jaxlib version used for
|
||||
# mypy checks.
|
||||
return xla_client.make_cpu_client( # type: ignore
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
)
|
||||
return xla_client.make_cpu_client()
|
||||
|
||||
|
||||
register_backend_factory(
|
||||
"cpu", xla_client.make_cpu_client, priority=0, fail_quietly=False
|
||||
"cpu", make_cpu_client, priority=0, fail_quietly=False
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user