diff --git a/CHANGELOG.md b/CHANGELOG.md index 5045de76c..7fbe947fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. + * Computations on the CPU backend may now be dispatched asynchronously in + more cases. Previously non-parallel computations were always dispatched + synchronously. You can recover the old behavior by setting + `jax.config.update('jax_cpu_enable_async_dispatch', False)`. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index d29e48ccf..c6f8684d2 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -108,11 +108,9 @@ CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag( ), ) -# TODO(yueshengys): turn default back to True after resolving memory increase -# issue. _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( name="jax_cpu_enable_async_dispatch", - default=False, + default=True, help="Only applies to non-parallel computations. If False, run computations" "inline without async dispatch.", )