mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13946 from jakevdp:faq-tracer
PiperOrigin-RevId: 501073465
This commit is contained in:
commit
3e52c2d3fd
41
docs/faq.rst
41
docs/faq.rst
@ -788,17 +788,46 @@ functions, for example :func:`jax.nn.softmax` can replace uses of
|
||||
:func:`jax.numpy.sign`, :func:`jax.nn.softplus` can replace uses of
|
||||
:func:`jax.nn.relu`, etc.
|
||||
|
||||
How can I convert a JAX Tracer to a NumPy array?
|
||||
------------------------------------------------
|
||||
When inspecting a transformed JAX function at runtime, you'll find that array
|
||||
values are replaced by :class:`~jax.core.Tracer` objects::
|
||||
|
||||
Additional Sections
|
||||
-------------------
|
||||
@jax.jit
|
||||
def f(x):
|
||||
print(type(x))
|
||||
return x
|
||||
|
||||
.. comment We refer to the anchor below in JAX error messages
|
||||
f(jnp.arange(5))
|
||||
|
||||
``Abstract tracer value encountered where concrete value is expected`` error
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
See :class:`jax.errors.ConcretizationTypeError`
|
||||
This prints the following::
|
||||
|
||||
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
|
||||
|
||||
A frequent question is how such a tracer can be converted back to a normal NumPy
|
||||
array. In short, **it is impossible to convert a Tracer to a NumPy array**, because
|
||||
a tracer is an abstract representation of *every possible* value with a given shape
|
||||
and dtype, while a numpy array is a concrete member of that abstract class.
|
||||
For more discussion of how tracers work within the context of JAX transformations,
|
||||
see `JIT mechanics`_.
|
||||
|
||||
The question of converting Tracers back to arrays usually comes up within
|
||||
the context of another goal, related to accessing intermediate values in a
|
||||
computation at runtime. For example:
|
||||
|
||||
- If you wish to print a traced value at runtime for debugging purposes, you might
|
||||
consider using :func:`jax.debug.print`.
|
||||
- If you wish to call non-JAX code within a transformed JAX function, you might
|
||||
consider using :func:`jax.pure_callback`, an example of which is available at
|
||||
`Pure callback example`_.
|
||||
|
||||
For more information on runtime callbacks and examples of their use,
|
||||
see `External callbacks in JAX`_.
|
||||
|
||||
|
||||
.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
|
||||
.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp
|
||||
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
|
||||
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
|
||||
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
|
||||
|
Loading…
x
Reference in New Issue
Block a user