We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
* [doc] Note that building jaxlib from source isn't always necessary
Building jaxlib from source is time-consuming and the source of most
pain for building JAX. It's also not necessary (in my experience) for
pure-Python changes.
This commit adds notes to the 'building from source' documentation to
make this explicit.
* Move ``jaxlib`` skip instructions to top
Also add a new section for the np.fft module. These functions were previously
not appearing in the docs, because fftn is not exposed as np.fftn but only as
np.fft.fftn.
* WIP: linear solvers
* Draft of lax.linear_solve
* Refactor pytree munging inside lax.root.
The primitive's implementation and JVP rules are now 100% pytree free.
* Fixup linear_solve
* Linearize multiple times in _root_jvp to avoid zeros
* fix deftraced
* add a symmetric argument
* Fixup float64; add a test for symmetric/non-symmetric
* test zeros in linear_solve_jvp
* Revisions per review
* Adjust signature of linear_solve
* restore botched test
* variable names
* WIP: root solve jaxpr
* WIP more tests
* rewrite root
* Root works with jaxprs
* root -> custom_root
* WIP undefined tangent
* Delayed undefined JVP errors
* use raise_on_undefined_tangents inside define_implicit_gradient
* more tests on jvps with undefined tangents
* Remove define_implicit_gradient
* Support closures in custom_root
* revert api-test
* another test
* jit tests
* spelling
* Moved all notebooks to docs/notebooks.
Now all notebooks are in the same place, thus all are subject
to auto-doc generation at readthedocs.io and to automated testing
with travis.
Some notebooks are too slow, exclude them at docs/conf.py:exclude_patterns.
Cleanup a bit the section headings in notebooks so that they show
up well in readtehdocs.io.
* Increase the cell timeout for executing notebooks
* Exclude also the neural network notebook from auto-generation (timing out)
* Disable the score_matching notebook from auto-doc (travis does not have sklearn)
Had to extend the docs/requirements.txt file to install
matplotlb (needed by the Gotchas notebook) and ".",
needed by everything. This results in a reduction
of the sphinx warnings from 3300 to 1200!
Testing is done by running "jupyter nbconvert --to notebook" and
then parsing the resulting notebook to look for errors.
One can declare expected errors, and the test will fail if those
are missing.
In the process of doig this, found and fixed a bug in the autodiff_cookbook
notebook.
* Cleaned up use of section levels
* Renamed ma to multiply_add and sq_add to square_add
* Other minor clarifications
* Separated the Colabs into Tutorials and Advanced Tutorials