2020-03-19 14:55:16 +01:00
|
|
|
JAX Frequently Asked Questions
|
|
|
|
==============================
|
|
|
|
|
2020-04-20 11:30:52 +03:00
|
|
|
.. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html
|
|
|
|
.. comment Some links referenced here. Use JAX_sharp_bits_ (underscore at the end) to reference
|
|
|
|
|
|
|
|
|
|
|
|
.. _JAX_sharp_bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
|
|
|
|
.. _How_JAX_primitives_work: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
|
|
|
|
|
2020-03-19 14:55:16 +01:00
|
|
|
We are collecting here answers to frequently asked questions.
|
|
|
|
Contributions welcome!
|
|
|
|
|
2020-03-24 10:22:49 +01:00
|
|
|
Creating arrays with `jax.numpy.array` is slower than with `numpy.array`
|
|
|
|
------------------------------------------------------------------------
|
|
|
|
|
|
|
|
The following code is relatively fast when using NumPy, and slow when using
|
|
|
|
JAX's NumPy::
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
np.array([0] * int(1e6))
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
The reason is that in NumPy the ``numpy.array`` function is implemented in C, while
|
|
|
|
the :func:`jax.numpy.array` is implemented in Python, and it needs to iterate over a long
|
2020-03-24 10:22:49 +01:00
|
|
|
list to convert each list element to an array element.
|
|
|
|
|
|
|
|
An alternative would be to create the array with original NumPy and then convert
|
|
|
|
it to a JAX array::
|
|
|
|
|
|
|
|
from jax import numpy as jnp
|
|
|
|
jnp.array(np.array([0] * int(1e6)))
|
|
|
|
|
2020-03-22 06:47:14 +01:00
|
|
|
`jit` changes the behavior of my function
|
|
|
|
-----------------------------------------
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
If you have a Python function that changes behavior after using :func:`jax.jit`, perhaps
|
2020-03-22 06:47:14 +01:00
|
|
|
your function uses global state, or has side-effects. In the following code, the
|
2020-04-19 12:13:07 +03:00
|
|
|
``impure_func`` uses the global ``y`` and has a side-effect due to ``print``::
|
2020-03-22 06:47:14 +01:00
|
|
|
|
|
|
|
y = 0
|
|
|
|
|
|
|
|
# @jit # Different behavior with jit
|
|
|
|
def impure_func(x):
|
|
|
|
print("Inside:", y)
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
for y in range(3):
|
|
|
|
print("Result:", impure_func(y))
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
Without ``jit`` the output is::
|
2020-03-22 06:47:14 +01:00
|
|
|
|
|
|
|
Inside: 0
|
|
|
|
Result: 0
|
|
|
|
Inside: 1
|
|
|
|
Result: 2
|
|
|
|
Inside: 2
|
|
|
|
Result: 4
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
and with ``jit`` it is::
|
2020-03-22 06:47:14 +01:00
|
|
|
|
|
|
|
Inside: 0
|
|
|
|
Result: 0
|
|
|
|
Result: 1
|
|
|
|
Result: 2
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
For :func:`jax.jit`, the function is executed once using the Python interpreter, at which time the
|
|
|
|
``Inside`` printing happens, and the first value of ``y`` is observed. Then the function
|
|
|
|
is compiled and cached, and executed multiple times with different values of ``x``, but
|
|
|
|
with the same first value of ``y``.
|
2020-03-22 06:47:14 +01:00
|
|
|
|
|
|
|
Additional reading:
|
|
|
|
|
2020-04-20 11:30:52 +03:00
|
|
|
* JAX_sharp_bits_
|
2020-04-19 12:13:07 +03:00
|
|
|
|
2020-04-22 10:25:06 +03:00
|
|
|
.. comment We refer to the anchor below in JAX error messages
|
|
|
|
|
|
|
|
`Abstract tracer value encountered where concrete value is expected` error
|
|
|
|
--------------------------------------------------------------------------
|
|
|
|
|
|
|
|
If you are getting an error that a library function is called with
|
|
|
|
*"Abstract tracer value encountered where concrete value is expected"*, you may need to
|
|
|
|
change how you invoke JAX transformations. We give first an example, and
|
|
|
|
a couple of solutions, and then we explain in more detail what is actually
|
|
|
|
happening, if you are curious or the simple solution does not work for you.
|
|
|
|
|
|
|
|
Some library functions take arguments that specify shapes or axes,
|
|
|
|
such as the 2nd and 3rd arguments for :func:`jax.numpy.split`::
|
|
|
|
|
|
|
|
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
|
|
|
|
np.split(np.zeros(2), 2, 0) # works
|
|
|
|
|
|
|
|
If you try the following code::
|
|
|
|
|
|
|
|
jax.jit(np.split)(np.zeros(4), 2, 0)
|
|
|
|
|
|
|
|
you will get the following error::
|
|
|
|
|
|
|
|
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
|
|
|
|
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
|
|
|
|
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
|
|
|
|
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
|
|
|
|
|
|
|
|
We must change the way we use :func:`jax.jit` to ensure that the ``num_sections``
|
|
|
|
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively).
|
|
|
|
The best mechanism is to use special transformation parameters
|
|
|
|
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`::
|
|
|
|
|
|
|
|
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
|
|
|
|
|
|
|
|
An alternative is to apply the transformation to a closure
|
|
|
|
that encapsulates the arguments to be protected, either manually as below
|
|
|
|
or by using ``functools.partial``::
|
|
|
|
|
|
|
|
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
|
|
|
|
|
|
|
|
**Note a new closure is created at every invocation, which defeats the
|
|
|
|
compilation caching mechanism, which is why static_argnums is preferred.**
|
|
|
|
|
|
|
|
To understand more subtleties having to do with tracers vs. regular values, and
|
|
|
|
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_.
|
|
|
|
|
|
|
|
Different kinds of JAX values
|
|
|
|
------------------------------
|
|
|
|
|
|
|
|
In the process of transforming functions, JAX replaces some some function
|
|
|
|
arguments with special tracer values.
|
|
|
|
You could see this if you use a ``print`` statement::
|
|
|
|
|
|
|
|
def func(x):
|
|
|
|
print(x)
|
|
|
|
return np.cos(x)
|
|
|
|
|
|
|
|
res = jax.jit(func)(0.)
|
|
|
|
|
|
|
|
The above code does return the correct value ``1.`` but it also prints
|
|
|
|
``Traced<ShapedArray(float32[])>`` for the value of ``x``. Normally, JAX
|
|
|
|
handles these tracer values internally in a transparent way, e.g.,
|
|
|
|
in the numeric JAX primitives that are used to implement the
|
|
|
|
``jax.numpy`` functions. This is why ``np.cos`` works in the example above.
|
|
|
|
|
|
|
|
More precisely, a **tracer** value is introduced for the argument of
|
|
|
|
a JAX-transformed function, except the arguments identified by special
|
|
|
|
parameters such as ``static_argnums`` for :func:`jax.jit` or
|
|
|
|
``static_broadcasted_argnums`` for :func:`jax.pmap`. Typically, computations
|
|
|
|
that involve at least a tracer value will produce a tracer value. Besides tracer
|
|
|
|
values, there are **regular** Python values: values that are computed outside JAX
|
|
|
|
transformations, or arise from above-mentioned static arguments of certain JAX
|
|
|
|
transformations, or computed solely from other regular Python values.
|
|
|
|
These are the values that are used everywhere in absence of JAX transformations.
|
|
|
|
|
|
|
|
A tracer value carries an **abstract** value, e.g., ``ShapedArray`` with information
|
|
|
|
about the shape and dtype of an array. We will refer here to such tracers as
|
|
|
|
**abstract tracers**. Some tracers, e.g., those that are
|
|
|
|
introduced for arguments of autodiff transformations, carry ``ConcreteArray``
|
|
|
|
abstract values that actually include the regular array data, and are used,
|
|
|
|
e.g., for resolving conditionals. We will refer here to such tracers
|
|
|
|
as **concrete tracers**. Tracer values computed from these concrete tracers,
|
|
|
|
perhaps in combination with regular values, result in concrete tracers.
|
|
|
|
A **concrete value** is either a regular value or a concrete tracer.
|
|
|
|
|
|
|
|
Most often values computed from tracer values are themselves tracer values.
|
|
|
|
There are very few exceptions, when a computation can be entirely done
|
|
|
|
using the abstract value carried by a tracer, in which case the result
|
|
|
|
can be a regular value. For example, getting the shape of a tracer
|
|
|
|
with ``ShapedArray`` abstract value. Another example, is when explicitly
|
|
|
|
casting a concrete tracer value to a regular type, e.g., ``int(x)`` or
|
|
|
|
``x.astype(float)``.
|
|
|
|
Another such situation is for ``bool(x)``, which produces a Python bool when
|
|
|
|
concreteness makes it possible. That case is especially salient because
|
|
|
|
of how often it arises in control flow.
|
|
|
|
|
|
|
|
Here is how the transformations introduce abstract or concrete tracers:
|
|
|
|
|
|
|
|
* :func:`jax.jit`: introduces **abstract tracers** for all positional arguments
|
|
|
|
except those denoted by ``static_argnums``, which remain regular
|
|
|
|
values.
|
|
|
|
* :func:`jax.pmap`: introduces **abstract tracers** for all positional arguments
|
|
|
|
except those denoted by ``static_broadcasted_argnums``.
|
|
|
|
* :func:`jax.vmap`, :func:`jax.make_jaxpr`, :func:`xla_computation`:
|
|
|
|
introduce **abstract tracers** for all positional arguments.
|
|
|
|
* :func:`jax.jvp` and :func:`jax.grad` introduce **concrete tracers**
|
|
|
|
for all positional arguments. An exception is when these transformations
|
|
|
|
are within an outer transformation and the actual arguments are
|
|
|
|
themselves abstract tracers; in that case, the tracers introduced
|
|
|
|
by the autodiff transformations are also abstract tracers.
|
|
|
|
* All higher-order control-flow primitives (:func:`lax.cond`, :func:`lax.while_loop`,
|
|
|
|
:func:`lax.fori_loop`, :func:`lax.scan`) when they process the functionals
|
|
|
|
introduce **abstract tracers**, whether or not there is a JAX transformation
|
|
|
|
in progress.
|
|
|
|
|
|
|
|
All of this is relevant when you have code that can operate
|
|
|
|
only on regular Python values, such as code that has conditional
|
|
|
|
control-flow based on data::
|
|
|
|
|
|
|
|
def divide(x, y):
|
|
|
|
return x / y if y >= 1. else 0.
|
|
|
|
|
|
|
|
If we want to apply :func:`jax.jit`, we must ensure to specify ``static_argnums=1``
|
|
|
|
to ensure ``y`` stays a regular value. This is due to the boolean expression
|
|
|
|
``y >= 1.``, which requires concrete values (regular or tracers). The
|
|
|
|
same would happen if we write explicitly ``bool(y >= 1.)``, or ``int(y)``,
|
|
|
|
or ``float(y)``.
|
|
|
|
|
|
|
|
Interestingly, ``jax.grad(divide)(3., 2.)``, works because :func:`jax.grad`
|
|
|
|
uses concrete tracers, and resolves the conditional using the concrete
|
|
|
|
value of ``y``.
|
2020-03-22 06:47:14 +01:00
|
|
|
|
2020-03-19 14:55:16 +01:00
|
|
|
Gradients contain `NaN` where using ``where``
|
|
|
|
------------------------------------------------
|
|
|
|
|
|
|
|
If you define a function using ``where`` to avoid an undefined value, if you
|
2020-04-19 12:13:07 +03:00
|
|
|
are not careful you may obtain a ``NaN`` for reverse differentiation::
|
2020-03-19 14:55:16 +01:00
|
|
|
|
|
|
|
def my_log(x):
|
|
|
|
return np.where(x > 0., np.log(x), 0.)
|
|
|
|
|
|
|
|
my_log(0.) ==> 0. # Ok
|
|
|
|
jax.grad(my_log)(0.) ==> NaN
|
|
|
|
|
|
|
|
A short explanation is that during ``grad`` computation the adjoint corresponding
|
|
|
|
to the undefined ``np.log(x)`` is a ``NaN`` and when it gets accumulated to the
|
|
|
|
adjoint of the ``np.where``. The correct way to write such functions is to ensure
|
|
|
|
that there is a ``np.where`` *inside* the partially-defined function, to ensure
|
|
|
|
that the adjoint is always finite::
|
|
|
|
|
|
|
|
def safe_for_grad_log(x):
|
|
|
|
return np.log(np.where(x > 0., x, 1.)
|
|
|
|
|
|
|
|
safe_for_grad_log(0.) ==> 0. # Ok
|
|
|
|
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
The inner ``np.where`` may be needed in addition to the original one, e.g.::
|
2020-03-19 14:55:16 +01:00
|
|
|
|
|
|
|
def my_log_or_y(x, y):
|
|
|
|
"""Return log(x) if x > 0 or y"""
|
|
|
|
return np.where(x > 0., np.log(np.where(x > 0., x, 1.), y)
|
|
|
|
|
|
|
|
|
|
|
|
Additional reading:
|
|
|
|
|
2020-04-19 12:13:07 +03:00
|
|
|
* `Issue: gradients through np.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_.
|
2020-04-20 11:30:52 +03:00
|
|
|
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.
|
|
|
|
|
|
|
|
Why do I get forward-mode differentiation error when I am trying to do reverse-mode differentiation?
|
|
|
|
-----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
JAX implements reverse-mode differentiation as a composition of two operations:
|
|
|
|
linearization and transposition. The linearization step (see :func:`jax.linearize`)
|
|
|
|
uses the JVP rules to form the forward-computation of tangents along with the intermediate
|
|
|
|
forward computations of intermediate values on which the tangents depend.
|
|
|
|
The transposition step will turn the forward-computation of tangents
|
|
|
|
into a reverse-mode computation.
|
|
|
|
|
|
|
|
If the JVP rule is not implemented for a primitive, then neither the forward-mode
|
|
|
|
nor the reverse-mode differentiation will work, but the error given will refer
|
|
|
|
to the forward-mode because that is the one that fails.
|
|
|
|
|
|
|
|
You can read more details at How_JAX_primitives_work_.
|
|
|
|
|