Merge pull request #12991 from jakevdp:fix-faq

PiperOrigin-RevId: 484042682
jax authors 2022-10-26 12:39:00 -07:00
commit a08ced86f3

@@ -55,31 +55,52 @@ Additional reading:
``jit`` changes the exact numerics of outputs
---------------------------------------------
Sometimes users are surprised by the fact that wrapping a function with `jit` can
make its outputs slightly different. For example:
Sometimes users are surprised by the fact that wrapping a function with :func:`jit`
can change the function's outputs. For example:
>>> from jax import jit
>>> def f(x, y):
...   return x + y - x
>>> x = jnp.array(1.0)
>>> y = jnp.array(0.001)
>>> print(f(x, y))
0.0010000467
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x, y))
0.001
>>> print(jit(f)(x))
0.5723649
This happens because of optimizations within the XLA compiler. During compilation,
XLA will often re-arrange floating point operations to simplify the expression it
computes. For example, consider the expression ``x + y - x`` above. In non-JIT
op-by-op evaluation, the addition and the subtraction each accumulate standard
32-bit floating point rounding error, so the result is not exactly equal to ``y``.
By contrast, in JIT the XLA compiler recognizes that the ``x`` and ``-x`` cancel
each other, and so it drops these terms and the return value is identically equal
to ``y``.
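The rounding behavior described here can be reproduced without JAX or XLA. The following is a minimal sketch in plain Python that emulates 32-bit arithmetic by rounding after every operation. The helper ``to_f32`` is a hypothetical name introduced for illustration and is not part of JAX, and the "simplified" branch only imitates the cancellation by hand rather than invoking the compiler.

```python
import struct


def to_f32(value: float) -> float:
    """Round a Python float (64-bit) to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", value))[0]


x = to_f32(1.0)
y = to_f32(0.001)

# Op-by-op evaluation: each intermediate result is rounded to float32.
op_by_op = to_f32(to_f32(x + y) - x)

# After cancelling x and -x algebraically, only y remains.
simplified = y

print(op_by_op)
print(simplified)
print(op_by_op == simplified)
```

The two printed values agree only to about four significant digits, because adding ``y`` to the much larger ``x`` discards most of the low-order bits of ``y`` before the subtraction can recover them.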
This slight difference in output comes from optimizations within the XLA compiler:
during compilation, XLA will sometimes rearrange or elide certain operations to make
the overall computation more efficient.
In general, for this and other related reasons, JIT-compiled code should be expected
to produce slightly different outputs than its non-JIT-compiled counterpart.
In this case, XLA utilizes the properties of the logarithm to replace ``log(sqrt(x))``
with ``0.5 * log(x)``, which is a mathematically identical expression that can be
computed more efficiently than the original. The difference in output comes from
the fact that floating point arithmetic is only a close approximation of real math,
so different ways of computing the same expression may have subtly different results.
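To see why two mathematically identical expressions can round differently, here is a rough sketch in plain Python that emulates float32 by rounding after each step. It does not use XLA's actual ``log`` or ``sqrt`` implementations, so the exact digits may not match the JAX output above; ``to_f32`` is a hypothetical helper introduced only for this illustration.

```python
import math
import struct


def to_f32(value: float) -> float:
    """Round a Python float (64-bit) to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", value))[0]


x = to_f32(math.pi)

# Original form: rounding happens after sqrt and again after log.
log_of_sqrt = to_f32(math.log(to_f32(math.sqrt(x))))

# Rewritten form: rounding happens after log; halving a normal
# floating point value is exact, so it adds no further error.
half_log = to_f32(0.5 * to_f32(math.log(x)))

print(log_of_sqrt, half_log)
print(abs(log_of_sqrt - half_log))
```

The two forms place their rounding steps at different points in the computation, so any disagreement between them is confined to the last bit or two of the float32 result.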
Other times, XLA's optimizations may lead to even more drastic differences.
Consider the following example:
>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
In non-JIT-compiled op-by-op mode, the result is ``inf`` because ``jnp.exp(x)``
overflows and returns ``inf``. Under JIT, however, XLA recognizes that ``log`` is
the inverse of ``exp``, and removes the operations from the compiled function,
simply returning the input. In this case, JIT compilation produces a more accurate
floating point approximation of the real result.
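The overflow in the op-by-op case can be sketched in plain Python by emulating float32 range limits. This is an illustration under stated assumptions rather than a reproduction of XLA: ``to_f32`` is a hypothetical helper that maps values too large for float32 to infinity, and the "simplified" result stands in by hand for the compiler removing ``log(exp(x))``.

```python
import math
import struct


def to_f32(value: float) -> float:
    """Round to float32, mapping out-of-range magnitudes to +/-inf."""
    try:
        return struct.unpack("f", struct.pack("f", value))[0]
    except OverflowError:
        # struct refuses to pack finite values beyond the float32 range.
        return math.copysign(math.inf, value)


x = to_f32(100.0)

# Op-by-op: exp(100) is about 2.7e43, beyond the float32 maximum (~3.4e38).
exp_x = to_f32(math.exp(x))
op_by_op = to_f32(math.log(exp_x))

# With log(exp(x)) simplified away, the input is returned unchanged.
simplified = x

print(exp_x, op_by_op, simplified)
```

Once the intermediate value has overflowed to infinity, no later operation can recover the original input, which is why only the simplified form returns a finite result.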
Unfortunately, the full list of XLA's algebraic simplifications is not well
documented, but if you're familiar with C++ and curious about what types of
optimizations the XLA compiler makes, you can see them in the source code:
`algebraic_simplifier.cc`_.
.. _faq-slow-compile:
@@ -785,3 +806,4 @@ See :class:`jax.errors.ConcretizationTypeError`
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266