mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False)
, one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
This commit is contained in:
parent
0ab4d68511
commit
f255fb700a
@ -15,6 +15,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* Changes
|
||||
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
|
||||
See {ref}`python-array-api` for more information.
|
||||
* Computations on the CPU backend may now be dispatched asynchronously in
|
||||
more cases. Previously non-parallel computations were always dispatched
|
||||
synchronously. You can recover the old behavior by setting
|
||||
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
|
||||
|
||||
* Breaking changes
|
||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||
|
@ -108,11 +108,9 @@ CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
|
||||
),
|
||||
)
|
||||
|
||||
# TODO(yueshengys): turn default back to True after resolving memory increase
|
||||
# issue.
|
||||
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
|
||||
name="jax_cpu_enable_async_dispatch",
|
||||
default=False,
|
||||
default=True,
|
||||
help="Only applies to non-parallel computations. If False, run computations"
|
||||
"inline without async dispatch.",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user