14 Commits

Author SHA1 Message Date
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
Stephan Hoyer
5bcbce744e
Support closures in all arguments of lax.custom_root (#1570)
* WIP: linear solvers

* Draft of lax.linear_solve

* Refactor pytree munging inside lax.root.

The primitive's implementation and JVP rules are now 100% pytree free.

* Fixup linear_solve

* Linearize multiple times in _root_jvp to avoid zeros

* fix deftraced

* add a symmetric argument

* Fixup float64; add a test for symmetric/non-symmetric

* test zeros in linear_solve_jvp

* Revisions per review

* Adjust signature of linear_solve

* restore botched test

* variable names

* WIP: root solve jaxpr

* WIP more tests

* rewrite root

* Root works with jaxprs

* root -> custom_root

* WIP undefined tangent

* Delayed undefined JVP errors

* use raise_on_undefined_tangents inside define_implicit_gradient

* more tests on jvps with undefined tangents

* Remove define_implicit_gradient

* Support closures in custom_root

* revert api-test

* another test

* jit tests

* spelling
2019-10-29 16:00:00 -07:00
Peter Hawkins
c485a3cc50
Remove stale reference to lapax.py. (#1546)
Add some missing documentation references.
2019-10-21 13:47:36 -04:00
Peter Hawkins
78132c150d Document all_to_all and ppermute. 2019-10-10 15:19:17 -04:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Stephan Hoyer
8c628a267b Implement lax.map
Fixes GH-1113
2019-08-05 12:14:05 -07:00
Peter Hawkins
4eb1820ae2 Add documentation to JAX modules. 2019-07-21 15:55:47 -04:00
Matthew Johnson
2f645dd36d remove one lax parallel op from sphinx docs 2019-05-17 13:24:06 -07:00
Matthew Johnson
9b1af47a59 improve documetnation of lax parallel operators 2019-05-17 12:27:09 -07:00
Skye Wanderman-Milne
5d1c014509 Initial FFT support.
This change creates a new fft primitive in lax, and uses it to implement numpy's np.fft.fftn function.

Not-yet-implemented functionality:
- vmap
- 's' argument of fftn
- other numpy np.fft functions

Resolves #505.
2019-05-16 14:37:30 -07:00
Peter Hawkins
c2369ce340 Add lax.scan to docs. 2019-05-13 14:46:35 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Peter Hawkins
86d8915c3d Add Sphinx-generated reference documentation for JAX. 2019-01-16 09:13:31 -05:00