Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 659741822
This commit is contained in:
Yue Sheng 2024-08-05 17:47:34 -07:00 committed by jax authors
parent 0ab4d68511
commit f255fb700a
2 changed files with 5 additions and 3 deletions

View File

@ -15,6 +15,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the

View File

@ -108,11 +108,9 @@ CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
),
)
# TODO(yueshengys): turn default back to True after resolving memory increase
# issue.
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
name="jax_cpu_enable_async_dispatch",
default=False,
default=True,
help="Only applies to non-parallel computations. If False, run computations"
"inline without async dispatch.",
)