From 803b83ee155eb76fe3d9db04e084e365e78e6a0b Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Wed, 6 Oct 2021 10:07:41 -0700 Subject: [PATCH] Enable C++ pmap. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On CPU: ``` name old cpu/op new cpu/op delta pmap_trivial_2_devices 128µs ± 6% 14µs ± 3% -89.06% (p=0.008 n=5+5) pmap_trivial_dispatch_8_devices 212µs ± 2% 35µs ± 1% -83.54% (p=0.008 n=5+5) pmap_trivial_8_devices 215µs ± 1% 40µs ± 4% -81.31% (p=0.008 n=5+5) pmap_simple_2_devices 123µs ± 5% 15µs ± 6% -87.70% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices 211µs ± 3% 35µs ± 2% -83.24% (p=0.008 n=5+5) pmap_simple_8_devices 217µs ± 5% 40µs ± 2% -81.68% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices_100_args 5.42ms ± 7% 0.52ms ± 2% -90.44% (p=0.008 n=5+5) pmap_simple_8_devices_100_args 26.5ms ±21% 17.5ms ±37% -34.18% (p=0.008 n=5+5) sda_index_1 7.45µs ± 6% 7.53µs ± 6% ~ (p=0.222 n=5+5) sda_index_2 14.1µs ± 1% 14.3µs ± 4% ~ (p=0.690 n=5+5) sda_index_8 56.0µs ± 3% 56.9µs ± 4% ~ (p=0.310 n=5+5) name old time/op new time/op delta pmap_trivial_2_devices 136µs ± 8% 19µs ± 3% -86.08% (p=0.008 n=5+5) pmap_trivial_dispatch_8_devices 216µs ± 3% 39µs ± 2% -81.94% (p=0.008 n=5+5) pmap_trivial_8_devices 219µs ± 2% 49µs ±38% -77.67% (p=0.008 n=5+5) pmap_simple_2_devices 130µs ± 5% 20µs ± 5% -84.38% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices 216µs ± 3% 39µs ± 5% -81.71% (p=0.008 n=5+5) pmap_simple_8_devices 221µs ± 6% 43µs ± 1% -80.41% (p=0.016 n=5+4) pmap_simple_dispatch_8_devices_100_args 5.52ms ± 7% 0.59ms ± 2% -89.28% (p=0.008 n=5+5) pmap_simple_8_devices_100_args 26.6ms ±21% 17.6ms ±37% -34.04% (p=0.008 n=5+5) sda_index_1 7.48µs ± 8% 7.53µs ± 6% ~ (p=0.310 n=5+5) sda_index_2 14.1µs ± 1% 14.3µs ± 4% ~ (p=0.690 n=5+5) sda_index_8 56.0µs ± 3% 56.9µs ± 4% ~ (p=0.310 n=5+5) ``` PiperOrigin-RevId: 401274089 --- CHANGELOG.md | 5 +++++ jax/_src/api.py | 7 ++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d705f34..f70261789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. unhashable static arguments into the function object. * `jax.util.partial` was an accidental export that has now been removed. Use `functools.partial` from the Python standard library instead. +* New features: + * A C++ code-path improving the dispatch time for pmap is now the default when + using jaxlib 0.1.72 or newer. The feature can be disabled using the + `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable). + It improves dispatch time, ## jax 0.2.21 (Sept 23, 2021) * [GitHub diff --git a/jax/_src/api.py b/jax/_src/api.py index a2cb841cb..8be9e02ad 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -104,12 +104,13 @@ FLAGS = flags.FLAGS flags.DEFINE_bool( "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True), - "A temporary flag enabling the C++ jax.jit fast path." + "A flag enabling the C++ jax.jit fast path." "Set this to `False` only if it crashes otherwise and report " "the error to the jax-team.") flags.DEFINE_bool( - "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False), - "A temporary flag enabling the C++ jax.pmap fast path. Until the default " + "experimental_cpp_pmap", + bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39), + "A flag enabling the C++ jax.pmap fast path. Until the default " "is switched to True, the feature is not supported and possibly broken " "(e.g. it may use unreleased code from jaxlib.")