mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
removed stale faq entries (#3565)
This commit is contained in:
parent
c2501d1bef
commit
062ce297dd
40
docs/faq.rst
40
docs/faq.rst
@ -1,5 +1,5 @@
|
||||
JAX Frequently Asked Questions
|
||||
==============================
|
||||
JAX Frequently Asked Questions (FAQ)
|
||||
====================================
|
||||
|
||||
.. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html
|
||||
.. comment Some links referenced here. Use JAX_sharp_bits_ (underscore at the end) to reference
|
||||
@ -11,25 +11,6 @@ JAX Frequently Asked Questions
|
||||
We are collecting here answers to frequently asked questions.
|
||||
Contributions welcome!
|
||||
|
||||
Creating arrays with `jax.numpy.array` is slower than with `numpy.array`
|
||||
------------------------------------------------------------------------
|
||||
|
||||
The following code is relatively fast when using NumPy, and slow when using
|
||||
JAX's NumPy::
|
||||
|
||||
import numpy as np
|
||||
np.array([0] * int(1e6))
|
||||
|
||||
The reason is that in NumPy the ``numpy.array`` function is implemented in C, while
|
||||
the :func:`jax.numpy.array` is implemented in Python, and it needs to iterate over a long
|
||||
list to convert each list element to an array element.
|
||||
|
||||
An alternative would be to create the array with original NumPy and then convert
|
||||
it to a JAX array::
|
||||
|
||||
from jax import numpy as jnp
|
||||
jnp.array(np.array([0] * int(1e6)))
|
||||
|
||||
`jit` changes the behavior of my function
|
||||
-----------------------------------------
|
||||
|
||||
@ -293,20 +274,3 @@ Additional reading:
|
||||
|
||||
* `Issue: gradients through np.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_.
|
||||
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.
|
||||
|
||||
Why do I get forward-mode differentiation error when I am trying to do reverse-mode differentiation?
|
||||
-----------------------------------------------------------------------------------------------------
|
||||
|
||||
JAX implements reverse-mode differentiation as a composition of two operations:
|
||||
linearization and transposition. The linearization step (see :func:`jax.linearize`)
|
||||
uses the JVP rules to form the forward-computation of tangents along with the intermediate
|
||||
forward computations of intermediate values on which the tangents depend.
|
||||
The transposition step will turn the forward-computation of tangents
|
||||
into a reverse-mode computation.
|
||||
|
||||
If the JVP rule is not implemented for a primitive, then neither the forward-mode
|
||||
nor the reverse-mode differentiation will work, but the error given will refer
|
||||
to the forward-mode because that is the one that fails.
|
||||
|
||||
You can read more details at How_JAX_primitives_work_.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user