Merge pull request #13946 from jakevdp:faq-tracer

PiperOrigin-RevId: 501073465
This commit is contained in:
jax authors 2023-01-10 13:07:56 -08:00
commit 3e52c2d3fd

View File

@ -788,17 +788,46 @@ functions, for example :func:`jax.nn.softmax` can replace uses of
:func:`jax.numpy.sign`, :func:`jax.nn.softplus` can replace uses of
:func:`jax.nn.relu`, etc.
How can I convert a JAX Tracer to a NumPy array?
------------------------------------------------
When inspecting a transformed JAX function at runtime, you'll find that array
values are replaced by :class:`~jax.core.Tracer` objects::
Additional Sections
-------------------
@jax.jit
def f(x):
print(type(x))
return x
.. comment We refer to the anchor below in JAX error messages
f(jnp.arange(5))
``Abstract tracer value encountered where concrete value is expected`` error
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :class:`jax.errors.ConcretizationTypeError`
This prints the following::
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
A frequent question is how such a tracer can be converted back to a normal NumPy
array. In short, **it is impossible to convert a Tracer to a NumPy array**, because
a tracer is an abstract representation of *every possible* value with a given shape
and dtype, while a numpy array is a concrete member of that abstract class.
For more discussion of how tracers work within the context of JAX transformations,
see `JIT mechanics`_.
The question of converting Tracers back to arrays usually comes up within
the context of another goal, related to accessing intermediate values in a
computation at runtime. For example:
- If you wish to print a traced value at runtime for debugging purposes, you might
consider using :func:`jax.debug.print`.
- If you wish to call non-JAX code within a transformed JAX function, you might
consider using :func:`jax.pure_callback`, an example of which is available at
`Pure callback example`_.
For more information on runtime callbacks and examples of their use,
see `External callbacks in JAX`_.
.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266