mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #12991 from jakevdp:fix-faq
PiperOrigin-RevId: 484042682
This commit is contained in:
commit
a08ced86f3
62
docs/faq.rst
62
docs/faq.rst
@ -55,31 +55,52 @@ Additional reading:
|
||||
|
||||
``jit`` changes the exact numerics of outputs
|
||||
---------------------------------------------
|
||||
Sometimes users are surprised by the fact that wrapping a function with `jit` can
|
||||
make its outputs slightly different. For example:
|
||||
Sometimes users are surprised by the fact that wrapping a function with :func:`jit`
|
||||
can change the function's outputs. For example:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> def f(x, y):
|
||||
... return x + y - x
|
||||
>>> x = jnp.array(1.0)
|
||||
>>> y = jnp.array(0.001)
|
||||
>>> print(f(x, y))
|
||||
0.0010000467
|
||||
>>> import jax.numpy as jnp
|
||||
>>> def f(x):
|
||||
... return jnp.log(jnp.sqrt(x))
|
||||
>>> x = jnp.pi
|
||||
>>> print(f(x))
|
||||
0.572365
|
||||
|
||||
>>> print(jit(f)(x, y))
|
||||
0.001
|
||||
>>> print(jit(f)(x))
|
||||
0.5723649
|
||||
|
||||
This happens because of optimizations within the XLA compiler. During compilation,
|
||||
XLA will often re-arrange floating point operations to simplify the expression it
|
||||
computes. For example, consider the expression ``x + y - x`` above. In non-JIT
|
||||
op-by-op evaluation, this addition and subtraction both accumulate standard
|
||||
32-bit floating point arithmetic error, so the result is not exactly equal ``y``.
|
||||
By contrast, in JIT the XLA compiler recognizes that the ``x`` and ``-x`` cancel
|
||||
each other, and so it drops these terms and the return value is identically equal
|
||||
to ``y``.
|
||||
This slight difference in output comes from optimizations within the XLA compiler:
|
||||
during compilation, XLA will sometimes rearrange or elide certain operations to make
|
||||
the overall computation more efficient.
|
||||
|
||||
In general, for this and other related reasons, it is to be expected that JIT-compiled
|
||||
code will produce slightly different outputs than its non-JIT compiled counterpart.
|
||||
In this case, XLA utilizes the properties of the logarithm to replace ``log(sqrt(x))``
|
||||
with ``0.5 * log(x)``, which is a mathematically identical expression that can be
|
||||
computed more efficiently than the original. The difference in output comes from
|
||||
the fact that floating point arithmetic is only a close approximation of real math,
|
||||
so different ways of computing the same expression may have subtly different results.
|
||||
|
||||
Other times, XLA's optimizations may lead to even more drastic differences.
|
||||
Consider the following example:
|
||||
|
||||
>>> def f(x):
|
||||
... return jnp.log(jnp.exp(x))
|
||||
>>> x = 100.0
|
||||
>>> print(f(x))
|
||||
inf
|
||||
|
||||
>>> print(jit(f)(x))
|
||||
100.0
|
||||
|
||||
In non-JIT-compiled op-by-op mode, the result is ``inf`` because ``jnp.exp(x)``
|
||||
overflows and returns ``inf``. Under JIT, however, XLA recognizes that ``log`` is
|
||||
the inverse of ``exp``, and removes the operations from the compiled function,
|
||||
simply returning the input. In this case, JIT compilation produces a more accurate
|
||||
floating point approximation of the real result.
|
||||
|
||||
Unfortunately the full list of XLA's algebraic simplifications is not well
|
||||
documented, but if you're familiar with C++ and curious about what types of
|
||||
optimizations the XLA compiler makes, you can see them in the source code:
|
||||
`algebraic_simplifier.cc`_.
|
||||
|
||||
.. _faq-slow-compile:
|
||||
|
||||
@ -785,3 +806,4 @@ See :class:`jax.errors.ConcretizationTypeError`
|
||||
|
||||
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
|
||||
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
|
||||
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
|
||||
|
Loading…
x
Reference in New Issue
Block a user