mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be False
by default, as we observed abnormal memory usage increases.
PiperOrigin-RevId: 625448228
This commit is contained in:
parent
47815c54d0
commit
1f83908bae
@ -106,9 +106,11 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
|
||||
help="If True, enable cross-process collectives on CPU using Gloo.",
|
||||
)
|
||||
|
||||
# TODO(yueshengys): turn default back to True after resolving memory increase
|
||||
# issue.
|
||||
_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool(
|
||||
name="jax_cpu_enable_async_dispatch",
|
||||
default=True,
|
||||
default=False,
|
||||
help="Only applies to non-parallel computations. If False, run computations"
|
||||
"inline without async dispatch.",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user