diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d705f34..f70261789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. unhashable static arguments into the function object. * `jax.util.partial` was an accidental export that has now been removed. Use `functools.partial` from the Python standard library instead. +* New features: + * A C++ code-path improving the dispatch time for pmap is now the default when + using jaxlib 0.1.72 or newer. The feature can be disabled using the + `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable). + It improves dispatch time, ## jax 0.2.21 (Sept 23, 2021) * [GitHub diff --git a/jax/_src/api.py b/jax/_src/api.py index a2cb841cb..8be9e02ad 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -104,12 +104,13 @@ FLAGS = flags.FLAGS flags.DEFINE_bool( "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True), - "A temporary flag enabling the C++ jax.jit fast path." + "A flag enabling the C++ jax.jit fast path." "Set this to `False` only if it crashes otherwise and report " "the error to the jax-team.") flags.DEFINE_bool( - "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False), - "A temporary flag enabling the C++ jax.pmap fast path. Until the default " + "experimental_cpp_pmap", + bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39), + "A flag enabling the C++ jax.pmap fast path. Until the default " "is switched to True, the feature is not supported and possibly broken " "(e.g. it may use unreleased code from jaxlib.")