removed stale faq entries (#3565)

This commit is contained in:
Matthew Johnson 2020-06-25 19:17:24 -07:00 committed by GitHub
parent c2501d1bef
commit 062ce297dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
JAX Frequently Asked Questions
==============================
JAX Frequently Asked Questions (FAQ)
====================================
.. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html
.. comment Some links referenced here. Use JAX_sharp_bits_ (underscore at the end) to reference
@ -11,25 +11,6 @@ JAX Frequently Asked Questions
We are collecting here answers to frequently asked questions.
Contributions welcome!
Creating arrays with `jax.numpy.array` is slower than with `numpy.array`
------------------------------------------------------------------------
The following code is relatively fast when using NumPy, and slow when using
JAX's NumPy::
import numpy as np
np.array([0] * int(1e6))
The reason is that in NumPy the ``numpy.array`` function is implemented in C, while
the :func:`jax.numpy.array` is implemented in Python, and it needs to iterate over a long
list to convert each list element to an array element.
An alternative would be to create the array with original NumPy and then convert
it to a JAX array::
from jax import numpy as jnp
jnp.array(np.array([0] * int(1e6)))
`jit` changes the behavior of my function
-----------------------------------------
@ -293,20 +274,3 @@ Additional reading:
* `Issue: gradients through np.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_.
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.
Why do I get forward-mode differentiation error when I am trying to do reverse-mode differentiation?
-----------------------------------------------------------------------------------------------------
JAX implements reverse-mode differentiation as a composition of two operations:
linearization and transposition. The linearization step (see :func:`jax.linearize`)
uses the JVP rules to form the forward-computation of tangents along with the intermediate
forward computations of intermediate values on which the tangents depend.
The transposition step will turn the forward-computation of tangents
into a reverse-mode computation.
If the JVP rule is not implemented for a primitive, then neither the forward-mode
nor the reverse-mode differentiation will work, but the error given will refer
to the forward-mode because that is the one that fails.
You can read more details at How_JAX_primitives_work_.