mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: Add a FAQ on "Benchmarking JAX code"
This should be a useful reference for users surprised by how slow JAX is :)
This commit is contained in:
parent
9bffe1ad05
commit
46db7bbe3b
80
docs/faq.rst
@@ -2,11 +2,9 @@ JAX Frequently Asked Questions (FAQ)
 ====================================
 
 .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html
-.. comment Some links referenced here. Use JAX_sharp_bits_ (underscore at the end) to reference
+.. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference
 
-.. _JAX_sharp_bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
-.. _How_JAX_primitives_work: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
+.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
 
 We are collecting here answers to frequently asked questions.
 Contributions welcome!
@@ -51,7 +49,7 @@ with the same first value of ``y``.
 
 Additional reading:
 
-* JAX_sharp_bits_
+* `JAX - The Sharp Bits`_
 
 .. _faq-slow-compile:
@@ -144,6 +142,78 @@ For a worked-out example, we recommend reading through
``test_computation_follows_data`` in
`multi_device_test.py <https://github.com/google/jax/blob/master/tests/multi_device_test.py>`_.

.. _faq-benchmark:

Benchmarking JAX code
---------------------

You just ported a tricky function from NumPy/SciPy to JAX. Did that actually
speed things up?

Keep in mind these important differences from NumPy when measuring the
speed of code using JAX:

1. **JAX code is Just-In-Time (JIT) compiled.** Most code written in JAX can be
   written in such a way that it supports JIT compilation, which can make it run
   *much faster* (see `To JIT or not to JIT`_). To get maximum performance from
   JAX, you should apply :func:`jax.jit` on your outermost function calls.

   Keep in mind that the first time you run JAX code, it will be slower because
   it is being compiled. This is true even if you don't use ``jit`` in your own
   code, because JAX's builtin functions are also JIT compiled.
2. **JAX has asynchronous dispatch.** This means that you need to call
   ``.block_until_ready()`` to ensure that computation has actually happened
   (see :ref:`async-dispatch`).
3. **JAX by default only uses 32-bit dtypes.** You may want to either explicitly
   use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see
   `Double (64 bit) precision`_) for a fair comparison.
4. **Transferring data between CPUs and accelerators takes time.** If you only
   want to measure how long it takes to evaluate a function, you may want to
   transfer the data to the device on which you want to run it first (see
   faq-data-placement_).

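Point 3 is easy to overlook because it matters even before any accelerator is involved: NumPy defaults to ``float64``, while JAX defaults to ``float32``, so a naive comparison times different amounts of work. A minimal sketch of pinning the dtype on the NumPy side (array shape chosen to match the example below):

```python
import numpy as np

# NumPy array creation defaults to float64; JAX defaults to float32.
x64 = np.ones((1000, 1000))                    # dtype is float64
x32 = np.ones((1000, 1000), dtype=np.float32)  # matches JAX's default dtype

print(x64.dtype)   # float64
print(x32.dtype)   # float32
print(x32.nbytes)  # half the memory traffic of the float64 version
```

Benchmark the ``float32`` version when comparing against JAX with its default settings.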
Here's an example of how to put together all these tricks into a microbenchmark
for comparing JAX versus NumPy, making use of IPython's convenient
`%time and %timeit magics`_::

    import numpy as np
    import jax.numpy as jnp
    import jax

    def f(x):  # function we're benchmarking (works in both NumPy & JAX)
      return x.T @ (x - x.mean(axis=0))

    x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
    %timeit f(x_np)  # measure NumPy runtime

    %time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
    f_jit = jax.jit(f)
    %time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
    %timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

When run with a GPU in Colab_, we see:

- NumPy takes 16.2 ms per evaluation on the CPU
- JAX takes 1.26 ms to copy the NumPy arrays onto the GPU
- JAX takes 193 ms to compile the function
- JAX takes 485 µs per evaluation on the GPU

In this case, we see that once the data is transferred and the function is
compiled, JAX on the GPU is about 30x faster for repeated evaluations.

Is this a fair comparison? Maybe. The performance that ultimately matters is for
running full applications, which inevitably include some amount of both data
transfer and compilation. Also, we were careful to pick large enough arrays
(1000x1000) and an intensive enough computation (the ``@`` operator is
performing matrix-matrix multiplication) to amortize the increased overhead of
JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use
10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).
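Outside of IPython, the same measurement pattern can be reproduced with the standard ``timeit`` module. A minimal sketch of the NumPy side, with the corresponding JAX calls (assuming ``jax`` is installed) indicated in comments:

```python
import timeit
import numpy as np

def f(x):
    # Same function as the microbenchmark above.
    return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)

f(x_np)  # warm-up call: excludes one-time costs (for JAX, compilation)
n = 10
per_call = timeit.timeit(lambda: f(x_np), number=n) / n
print(f"NumPy: {per_call * 1e3:.2f} ms per call")

# JAX equivalent: block on the result so asynchronous dispatch doesn't
# make the timing meaningless, and warm up once to exclude compilation:
#   f_jit = jax.jit(f)
#   f_jit(x_jax).block_until_ready()  # compile once
#   timeit.timeit(lambda: f_jit(x_jax).block_until_ready(), number=n)
```

The key design point is the same as with the ``%timeit`` magic: one warm-up call before the timing loop, and (for JAX) a block on every timed call.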

.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit
.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
.. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time
.. _Colab: https://colab.research.google.com/

.. comment We refer to the anchor below in JAX error messages

``Abstract tracer value encountered where concrete value is expected`` error