mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
docs: add FAQ section about jit compilation & numerics
This commit is contained in:
parent
113cd9b939
commit
b889282f6d
30
docs/faq.rst
30
docs/faq.rst
@ -51,6 +51,36 @@ Additional reading:
|
||||
|
||||
* `JAX - The Sharp Bits`_
|
||||
|
||||
.. _faq-jit-numerics:
|
||||
|
||||
``jit`` changes the exact numerics of outputs
|
||||
---------------------------------------------
|
||||
Sometimes users are surprised by the fact that wrapping a function with `jit` can
|
||||
make its outputs slightly different. For example:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> def f(x, y):
|
||||
... return x + y - x
|
||||
>>> x = jnp.array(1.0)
|
||||
>>> y = jnp.array(0.001)
|
||||
>>> print(f(x, y))
|
||||
0.0010000467
|
||||
|
||||
>>> print(jit(f)(x, y))
|
||||
0.001
|
||||
|
||||
This happens because of optimizations within the XLA compiler. During compilation,
|
||||
XLA will often re-arrange floating point operations to simplify the expression it
|
||||
computes. For example, consider the expression ``x + y - x`` above. In non-JIT
|
||||
op-by-op evaluation, this addition and subtraction both accumulate standard
|
||||
32-bit floating point arithmetic error, so the result is not exactly equal ``y``.
|
||||
By contrast, in JIT the XLA compiler recognizes that the ``x`` and ``-x`` cancel
|
||||
each other, and so it drops these terms and the return value is identically equal
|
||||
to ``y``.
|
||||
|
||||
In general, for this and other related reasons, it is to be expected that JIT-compiled
|
||||
code will produce slightly different outputs than its non-JIT compiled counterpart.
|
||||
|
||||
.. _faq-slow-compile:
|
||||
|
||||
``jit`` decorated function is very slow to compile
|
||||
|
Loading…
x
Reference in New Issue
Block a user