docs: add FAQ section about jit compilation & numerics

This commit is contained in:
Jake VanderPlas 2021-12-16 15:17:34 -08:00
parent 113cd9b939
commit b889282f6d

View File

@ -51,6 +51,36 @@ Additional reading:
* `JAX - The Sharp Bits`_
.. _faq-jit-numerics:
``jit`` changes the exact numerics of outputs
---------------------------------------------
Sometimes users are surprised by the fact that wrapping a function with `jit` can
make its outputs slightly different. For example:
>>> from jax import jit
>>> def f(x, y):
... return x + y - x
>>> x = jnp.array(1.0)
>>> y = jnp.array(0.001)
>>> print(f(x, y))
0.0010000467
>>> print(jit(f)(x, y))
0.001
This happens because of optimizations within the XLA compiler. During compilation,
XLA will often re-arrange floating point operations to simplify the expression it
computes. For example, consider the expression ``x + y - x`` above. In non-JIT
op-by-op evaluation, this addition and subtraction both accumulate standard
32-bit floating point arithmetic error, so the result is not exactly equal ``y``.
By contrast, in JIT the XLA compiler recognizes that the ``x`` and ``-x`` cancel
each other, and so it drops these terms and the return value is identically equal
to ``y``.
In general, for this and other related reasons, it is to be expected that JIT-compiled
code will produce slightly different outputs than its non-JIT compiled counterpart.
.. _faq-slow-compile:
``jit`` decorated function is very slow to compile