mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Async dispatch expensive computations on the JAX CPU backend.
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive. Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs. PiperOrigin-RevId: 625064815
This commit is contained in:
parent
eb92a5c711
commit
64775d02a3
@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||
* Async dispatch expensive computations on the CPU backend. This only applies
|
||||
to non-parallel computations, as we already do async dispatch for parallel
|
||||
computations. You can recover the old behavior by setting
|
||||
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
|
||||
|
||||
* Deprecations & Removals
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
|
@ -106,6 +106,13 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
|
||||
help="If True, enable cross-process collectives on CPU using Gloo.",
|
||||
)
|
||||
|
||||
_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool(
|
||||
name="jax_cpu_enable_async_dispatch",
|
||||
default=True,
|
||||
help="Only applies to non-parallel computations. If False, run computations"
|
||||
"inline without async dispatch.",
|
||||
)
|
||||
|
||||
|
||||
# Warn the user if they call fork(), because it's not going to go well for them.
|
||||
def _at_fork():
|
||||
@ -224,6 +231,14 @@ def make_cpu_client() -> xla_client.Client:
|
||||
collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore
|
||||
distributed_client=distributed.global_state.client,
|
||||
)
|
||||
if xla_extension_version >= 257:
|
||||
return xla_client.make_cpu_client( # type: ignore
|
||||
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
collectives=collectives,
|
||||
)
|
||||
return xla_client.make_cpu_client( # type: ignore
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
|
Loading…
x
Reference in New Issue
Block a user