Enable C++ pmap.

On CPU:
```
name                                     old cpu/op  new cpu/op  delta
pmap_trivial_2_devices                    128µs ± 6%    14µs ± 3%  -89.06%  (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           212µs ± 2%    35µs ± 1%  -83.54%  (p=0.008 n=5+5)
pmap_trivial_8_devices                    215µs ± 1%    40µs ± 4%  -81.31%  (p=0.008 n=5+5)
pmap_simple_2_devices                     123µs ± 5%    15µs ± 6%  -87.70%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            211µs ± 3%    35µs ± 2%  -83.24%  (p=0.008 n=5+5)
pmap_simple_8_devices                     217µs ± 5%    40µs ± 2%  -81.68%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices_100_args  5.42ms ± 7%  0.52ms ± 2%  -90.44%  (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.5ms ±21%  17.5ms ±37%  -34.18%  (p=0.008 n=5+5)
sda_index_1                              7.45µs ± 6%  7.53µs ± 6%     ~     (p=0.222 n=5+5)
sda_index_2                              14.1µs ± 1%  14.3µs ± 4%     ~     (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%  56.9µs ± 4%     ~     (p=0.310 n=5+5)

name                                     old time/op             new time/op             delta
pmap_trivial_2_devices                    136µs ± 8%               19µs ± 3%  -86.08%          (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           216µs ± 3%               39µs ± 2%  -81.94%          (p=0.008 n=5+5)
pmap_trivial_8_devices                    219µs ± 2%               49µs ±38%  -77.67%          (p=0.008 n=5+5)
pmap_simple_2_devices                     130µs ± 5%               20µs ± 5%  -84.38%          (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            216µs ± 3%               39µs ± 5%  -81.71%          (p=0.008 n=5+5)
pmap_simple_8_devices                     221µs ± 6%               43µs ± 1%  -80.41%          (p=0.016 n=5+4)
pmap_simple_dispatch_8_devices_100_args  5.52ms ± 7%             0.59ms ± 2%  -89.28%          (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.6ms ±21%             17.6ms ±37%  -34.04%          (p=0.008 n=5+5)
sda_index_1                              7.48µs ± 8%             7.53µs ± 6%     ~             (p=0.310 n=5+5)
sda_index_2                              14.1µs ± 1%             14.3µs ± 4%     ~             (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%             56.9µs ± 4%     ~             (p=0.310 n=5+5)
```

PiperOrigin-RevId: 401274089
Jean-Baptiste Lespiau 2021-10-06 10:07:41 -07:00 committed by jax authors
parent b9b0ce5d7c
commit 803b83ee15
2 changed files with 9 additions and 3 deletions

@@ -28,6 +28,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 unhashable static arguments into the function object.
 * `jax.util.partial` was an accidental export that has now been removed. Use
 `functools.partial` from the Python standard library instead.
+* New features:
+* A C++ code-path improving the dispatch time for pmap is now the default when
+using jaxlib 0.1.72 or newer. The feature can be disabled using the
+`--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
+It improves dispatch time,
 ## jax 0.2.21 (Sept 23, 2021)
 * [GitHub

@@ -104,12 +104,13 @@ FLAGS = flags.FLAGS
 flags.DEFINE_bool(
 "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
-"A temporary flag enabling the C++ jax.jit fast path."
+"A flag enabling the C++ jax.jit fast path."
 "Set this to `False` only if it crashes otherwise and report "
 "the error to the jax-team.")
 flags.DEFINE_bool(
-"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False),
-"A temporary flag enabling the C++ jax.pmap fast path. Until the default "
+"experimental_cpp_pmap",
+bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39),
+"A flag enabling the C++ jax.pmap fast path. Until the default "
 "is switched to True, the feature is not supported and possibly broken "
 "(e.g. it may use unreleased code from jaxlib.")
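
The new default is version-gated: the C++ pmap path is on when the extension version is at least 39, unless `JAX_CPP_PMAP` overrides it. Below is a minimal, self-contained sketch of that pattern. The body of `bool_env`, the accepted truth strings, and the `cpp_pmap_default` helper are assumptions made for illustration, not the actual JAX implementation. Only the variable name and the threshold of 39 come from the diff.

```python
import os

# Threshold taken from the diff: the C++ pmap path becomes the default
# when the XLA extension version is at least 39.
MIN_VERSION_FOR_CPP_PMAP = 39

# Assumed spellings of true/false values; the real helper may accept others.
_TRUE = {"y", "yes", "t", "true", "on", "1"}
_FALSE = {"n", "no", "f", "false", "off", "0"}


def bool_env(varname: str, default: bool) -> bool:
    """Read a boolean from the environment, falling back to `default`.

    Hypothetical re-implementation for illustration only.
    """
    val = os.environ.get(varname)
    if val is None:
        return default
    val = val.strip().lower()
    if val in _TRUE:
        return True
    if val in _FALSE:
        return False
    raise ValueError(
        f"invalid truth value {val!r} for environment variable {varname!r}")


def cpp_pmap_default(extension_version: int) -> bool:
    """Hypothetical helper: version-gated default with an env override."""
    return bool_env("JAX_CPP_PMAP",
                    extension_version >= MIN_VERSION_FOR_CPP_PMAP)
```

Under these assumptions, an explicit `JAX_CPP_PMAP=0` turns the feature off even with a new enough jaxlib, and `JAX_CPP_PMAP=1` forces it on with an older one, which the flag description warns may be broken.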