mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Enable C++ pmap.
On CPU: ``` name old cpu/op new cpu/op delta pmap_trivial_2_devices 128µs ± 6% 14µs ± 3% -89.06% (p=0.008 n=5+5) pmap_trivial_dispatch_8_devices 212µs ± 2% 35µs ± 1% -83.54% (p=0.008 n=5+5) pmap_trivial_8_devices 215µs ± 1% 40µs ± 4% -81.31% (p=0.008 n=5+5) pmap_simple_2_devices 123µs ± 5% 15µs ± 6% -87.70% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices 211µs ± 3% 35µs ± 2% -83.24% (p=0.008 n=5+5) pmap_simple_8_devices 217µs ± 5% 40µs ± 2% -81.68% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices_100_args 5.42ms ± 7% 0.52ms ± 2% -90.44% (p=0.008 n=5+5) pmap_simple_8_devices_100_args 26.5ms ±21% 17.5ms ±37% -34.18% (p=0.008 n=5+5) sda_index_1 7.45µs ± 6% 7.53µs ± 6% ~ (p=0.222 n=5+5) sda_index_2 14.1µs ± 1% 14.3µs ± 4% ~ (p=0.690 n=5+5) sda_index_8 56.0µs ± 3% 56.9µs ± 4% ~ (p=0.310 n=5+5) name old time/op new time/op delta pmap_trivial_2_devices 136µs ± 8% 19µs ± 3% -86.08% (p=0.008 n=5+5) pmap_trivial_dispatch_8_devices 216µs ± 3% 39µs ± 2% -81.94% (p=0.008 n=5+5) pmap_trivial_8_devices 219µs ± 2% 49µs ±38% -77.67% (p=0.008 n=5+5) pmap_simple_2_devices 130µs ± 5% 20µs ± 5% -84.38% (p=0.008 n=5+5) pmap_simple_dispatch_8_devices 216µs ± 3% 39µs ± 5% -81.71% (p=0.008 n=5+5) pmap_simple_8_devices 221µs ± 6% 43µs ± 1% -80.41% (p=0.016 n=5+4) pmap_simple_dispatch_8_devices_100_args 5.52ms ± 7% 0.59ms ± 2% -89.28% (p=0.008 n=5+5) pmap_simple_8_devices_100_args 26.6ms ±21% 17.6ms ±37% -34.04% (p=0.008 n=5+5) sda_index_1 7.48µs ± 8% 7.53µs ± 6% ~ (p=0.310 n=5+5) sda_index_2 14.1µs ± 1% 14.3µs ± 4% ~ (p=0.690 n=5+5) sda_index_8 56.0µs ± 3% 56.9µs ± 4% ~ (p=0.310 n=5+5) ``` PiperOrigin-RevId: 401274089
This commit is contained in:
parent
b9b0ce5d7c
commit
803b83ee15
@ -28,6 +28,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
unhashable static arguments into the function object.
|
||||
* `jax.util.partial` was an accidental export that has now been removed. Use
|
||||
`functools.partial` from the Python standard library instead.
|
||||
* New features:
|
||||
* A C++ code-path improving the dispatch time for pmap is now the default when
|
||||
using jaxlib 0.1.72 or newer. The feature can be disabled using the
|
||||
`--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
|
||||
It improves dispatch time,
|
||||
|
||||
## jax 0.2.21 (Sept 23, 2021)
|
||||
* [GitHub
|
||||
|
@ -104,12 +104,13 @@ FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
|
||||
"A temporary flag enabling the C++ jax.jit fast path."
|
||||
"A flag enabling the C++ jax.jit fast path."
|
||||
"Set this to `False` only if it crashes otherwise and report "
|
||||
"the error to the jax-team.")
|
||||
flags.DEFINE_bool(
|
||||
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False),
|
||||
"A temporary flag enabling the C++ jax.pmap fast path. Until the default "
|
||||
"experimental_cpp_pmap",
|
||||
bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39),
|
||||
"A flag enabling the C++ jax.pmap fast path. Until the default "
|
||||
"is switched to True, the feature is not supported and possibly broken "
|
||||
"(e.g. it may use unreleased code from jaxlib.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user