Asynchronously dispatch expensive computations on the JAX CPU backend.

Before this change, the CPU backend always ran computations inline unless there were multiple CPU devices with potential collectives. Now we use `HloCostAnalysis` to estimate the cost of a computation and dispatch it asynchronously if it is expensive.
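The decision described above can be pictured with a small, purely illustrative sketch. It is not XLA's implementation: `dispatch`, `EXPENSIVE_COST_THRESHOLD`, and the caller-supplied `estimated_cost` are hypothetical stand-ins for the `HloCostAnalysis` estimate, and the real threshold is not stated in this commit.

```python
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical threshold; the real heuristic lives in XLA and is not shown here.
EXPENSIVE_COST_THRESHOLD = 1_000

_pool = ThreadPoolExecutor(max_workers=1)


def dispatch(fn, estimated_cost, *args):
    """Run cheap work inline; hand expensive work to a worker thread."""
    if estimated_cost < EXPENSIVE_COST_THRESHOLD:
        done = Future()
        done.set_result(fn(*args))  # computed on the caller's thread
        return done
    return _pool.submit(fn, *args)  # returns at once; the result arrives later


cheap = dispatch(sum, 10, range(10))
costly = dispatch(sum, 10**6, range(10**6))
print(cheap.result(), costly.result())
```

In both cases the caller gets a future-like handle, so the calling code does not need to know which path was taken.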

Add a JAX flag that lets users opt out by calling `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.
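A minimal usage sketch of the opt-out follows. The diff passes the flag value to `make_cpu_client`, so this assumes the flag has to be set before the CPU backend is first used; the array shape and the matrix product are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Opt out before any computation runs, since the flag is read when the
# CPU client is created (assumption based on make_cpu_client in this diff).
jax.config.update('jax_cpu_enable_async_dispatch', False)

x = jnp.ones((512, 512))
y = jnp.dot(x, x)  # with the flag off, CPU computations run inline

# block_until_ready() is still safe to call and is what you would use to
# wait for the result when async dispatch is left enabled.
y.block_until_ready()
```

With the default (`True`), timing code should call `block_until_ready()` before stopping the clock, because an expensive computation may return control before it has finished.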

PiperOrigin-RevId: 625064815
Yue Sheng 2024-04-15 13:28:56 -07:00 committed by jax authors
parent eb92a5c711
commit 64775d02a3
2 changed files with 19 additions and 0 deletions


@@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* Expensive computations on the CPU backend are now dispatched asynchronously.
This only applies to non-parallel computations, as parallel computations were
already dispatched asynchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old


@@ -106,6 +106,13 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
    help="If True, enable cross-process collectives on CPU using Gloo.",
)
_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool(
    name="jax_cpu_enable_async_dispatch",
    default=True,
    help="Only applies to non-parallel computations. If False, run computations "
    "inline without async dispatch.",
)
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
@@ -224,6 +231,14 @@ def make_cpu_client() -> xla_client.Client:
      collectives = xla_client._xla.make_gloo_tcp_collectives(  # type: ignore
        distributed_client=distributed.global_state.client,
      )
    if xla_extension_version >= 257:
      return xla_client.make_cpu_client(  # type: ignore
        asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
        distributed_client=distributed.global_state.client,
        node_id=distributed.global_state.process_id,
        num_nodes=distributed.global_state.num_processes,
        collectives=collectives,
      )
    return xla_client.make_cpu_client(  # type: ignore
      distributed_client=distributed.global_state.client,
      node_id=distributed.global_state.process_id,