2140 Commits

Author SHA1 Message Date
Archis Joglekar
ca15512932 added fft2 and ifft2, corresponding tests, and documentation links. (#1939) 2020-01-04 18:21:30 -08:00
Archis Joglekar
d9c6a5f4a8 fft and ifft implementation (#1926)
* first commit with placeholders for tests

* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis

added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds

* removed placeholders
added functions to .rst file
2020-01-02 17:35:22 -08:00
John Mellor
fd6067471e Fix minor typo in Common_Gotchas_in_JAX.ipynb
Moved misplaced backtick
2020-01-02 07:43:59 -08:00
Yo
322ebe7c9b Update docs/conf.py 2019-12-30 11:27:12 -08:00
flowed
e0693fe649 Fix Typos 2019-12-30 11:27:12 -08:00
Matthew Johnson
f5723848d3
fix error in autodiff cookbook: 3x not 2x 2019-12-30 07:36:36 -08:00
David Bieber
30bede1f6a fix typo in autodiff cookbook (#1921) 2019-12-27 11:02:06 -08:00
Peter Hawkins
698babf9ec
Implement jax.numpy.nonzero and 1-argument jax.numpy.where. (#1905)
* Implement jax.numpy.nonzero.

* Implement the one-argument form of np.where.

* Fix output type and error message.

* Add lax_description strings to where and nonzero.
2019-12-20 18:42:33 -05:00
Matthew Johnson
8dad859e04 streamline readme, add pmap 2019-12-14 10:32:04 -08:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
Peter Hawkins
e87d9718c3
Support IntEnum values as arguments to JAX functions. (#1840)
* Support IntEnum values as arguments to JAX functions.

When abstractifying a Python value, search the method-resolution order (MRO) of the type rather than only looking at the value's own type. IntEnum instances are subclasses of int, so this allows us to correctly handle them as integers, much as NumPy itself does.
2019-12-11 12:27:11 -05:00
tamaranorman
26e863923a Support atrous conv in same padded convolution and add warning if use transposed convolution with same or valid padding. (#1806)
PiperOrigin-RevId: 283517237
2019-12-09 08:06:59 -08:00
Peter Hawkins
d958f3007d
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.

NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.

This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)   
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)   
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)    
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)    
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
Peter Hawkins
f0d9333379
Document functions in jax.nn. (#1795) 2019-12-02 14:21:10 -05:00
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
b0ffbaf1f6 Fixed also a notebook that has gone stale 2019-11-27 07:26:46 +01:00
George Necula
8777864c96 Minor edits 2019-11-24 20:29:44 +01:00
George Necula
b12a8019c8 Update docs/notebooks/JAX_pytrees.ipynb
Co-Authored-By: Stephan Hoyer <shoyer@google.com>
2019-11-24 20:29:29 +01:00
George Necula
4e89d43a75 Added JAX pytrees notebook
Also added docstrings to the tree_util module.
2019-11-24 20:29:07 +01:00
Skye Wanderman-Milne
6f3cb1c3ee Add jax.devices(), etc. to the docs. 2019-11-22 11:03:42 -08:00
Peter Hawkins
c60f3fd65d
Minor documentation fixes. (#1734) 2019-11-21 09:51:26 -05:00
Stephan Hoyer
ee29705712
Add jax.scipy.ndimage to online docs (#1724) 2019-11-20 12:35:10 -08:00
Peter Buchlovsky
9d1204689f Fix typo 2019-11-20 08:53:01 -08:00
Peter Buchlovsky
410ebfeb1c Fix typo 2019-11-20 08:52:46 -08:00
George Necula
397a244e7f
Merge pull request #1706 from gnecula/loops
An implementation of an experimental syntactic sugar for 'for' and `while` loops and conditionals.
2019-11-18 12:17:59 +01:00
Anselm Levskaya
f882359511
fix lax.scan notes in gotchas notebook
Note that lax.scan is now jittable and differentiable in the Gotchas notebook.
2019-11-17 00:19:24 -08:00
George Necula
d24c374d59 An implementation of an experimental syntactic sugar for 'for' loops.
See description in jax/experimental/loops.py.
2019-11-16 17:23:40 +01:00
George Necula
c6d3270512 Fixed tests for X64 2019-11-14 12:54:30 +01:00
Trevor Cai
340d82e93e [doc] Note that building jaxlib from source isn't always necessary (#1654)
* [doc] Note that building jaxlib from source isn't always necessary

Building jaxlib from source is time-consuming and the source of most
pain for building JAX. It's also not necessary (in my experience) for
pure-Python changes.

This commit adds notes to the 'building from source' documentation to
make this explicit.

* Move ``jaxlib`` skip instructions to top
2019-11-13 13:25:39 -05:00
Matthew Johnson
7a9ea8a006
Merge pull request #1582 from sharadmv/custom-interpreter
Add custom interpreter notebook
2019-11-08 13:13:24 -08:00
Sharad Vikram
7bc2b0878a Update description of eqn.parmas 2019-11-08 13:11:17 -08:00
Sharad Vikram
1f40c9c4d2 Fix writing suggestions from mattjj 2019-11-03 15:54:05 -08:00
Stephan Hoyer
e6ad9c29da
Docstring fixss for lax.custom_linear_solve (#1616)
Also add a new section for the np.fft module. These functions were previously
not appearing in the docs, because fftn is not exposed as np.fftn but only as
np.fft.fftn.
2019-11-01 09:04:44 -07:00
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
Stephan Hoyer
5bcbce744e
Support closures in all arguments of lax.custom_root (#1570)
* WIP: linear solvers

* Draft of lax.linear_solve

* Refactor pytree munging inside lax.root.

The primitive's implementation and JVP rules are now 100% pytree free.

* Fixup linear_solve

* Linearize multiple times in _root_jvp to avoid zeros

* fix deftraced

* add a symmetric argument

* Fixup float64; add a test for symmetric/non-symmetric

* test zeros in linear_solve_jvp

* Revisions per review

* Adjust signature of linear_solve

* restore botched test

* variable names

* WIP: root solve jaxpr

* WIP more tests

* rewrite root

* Root works with jaxprs

* root -> custom_root

* WIP undefined tangent

* Delayed undefined JVP errors

* use raise_on_undefined_tangents inside define_implicit_gradient

* more tests on jvps with undefined tangents

* Remove define_implicit_gradient

* Support closures in custom_root

* revert api-test

* another test

* jit tests

* spelling
2019-10-29 16:00:00 -07:00
George Necula
8880e262b0
Use redthedocs links for Colabs (#1572)
Steer the documentation readers to readthedocs.
Also, minor fixes to the wording of How_jax_primitives_work, suggested by Dougal
2019-10-29 08:53:35 +01:00
Sharad Vikram
e2e4e6e955 Fix title toc structure 2019-10-28 13:59:16 -07:00
Sharad Vikram
5d56999913 Add custom interpreter notebook 2019-10-28 13:58:55 -07:00
George Necula
0ffcd769ef
Add sklearn to Travis, for documentation building. (#1547)
* Add sklearn to Travis, for documentation building.
* Add score_matching to auto-built notebooks
2019-10-21 23:24:16 +02:00
Peter Hawkins
c485a3cc50
Remove stale reference to lapax.py. (#1546)
Add some missing documentation references.
2019-10-21 13:47:36 -04:00
Peter Hawkins
9c23a95e6a
Add i0e and i1e Bessel functions. (#1541) 2019-10-21 10:30:55 -04:00
George Necula
eae59d0b2c
Moved all notebooks to docs/notebooks. (#1493)
* Moved all notebooks to docs/notebooks.

Now all notebooks are in the same place, thus all are subject
to auto-doc generation at readthedocs.io and to automated testing
with travis.

Some notebooks are too slow, exclude them at docs/conf.py:exclude_patterns.

Cleanup a bit the section headings in notebooks so that they show
up well in readtehdocs.io.

* Increase the cell timeout for executing notebooks
* Exclude also the neural network notebook from auto-generation (timing out)
* Disable the score_matching notebook from auto-doc (travis does not have sklearn)
2019-10-17 08:58:25 +02:00
Trevor Cai
c50495b272 Fix List rendering in RTD for primitives.ipynb (#1501)
Colab doesn't require a newline before unordered list in Markdown; RTD
does.
2019-10-14 10:55:13 -07:00
Peter Hawkins
78132c150d Document all_to_all and ppermute. 2019-10-10 15:19:17 -04:00
George Necula
b2493a1ede
Merge pull request #1474 from gnecula/documentation
Create developer documentation.
2019-10-10 09:03:01 +02:00
George Necula
a9d9504348 Fixes to the documentation
* Included "Building from source" in README.md
* Added references from docs/README.md to docs/developer.rst
2019-10-09 17:45:09 +02:00
George Necula
e42c010605 Create developer documentation.
* Moved out of README.md some developer-only stuff to docs/developer.rst.
    * Added documentation about building the documentation
2019-10-09 17:24:01 +02:00
George Necula
c9d984b328 Fixed the readthedocs documentation build
Had to extend the docs/requirements.txt file to install
matplotlb (needed by the Gotchas notebook) and ".",
needed by everything. This results in a reduction
of the sphinx warnings from 3300 to 1200!
2019-10-09 14:43:42 +02:00
George Necula
41457633cc Increase readthedocs/nbsphinx timeout
It seems tht RTD is timing out when compiling the How_JAX_primitives_work noteboook.
2019-10-07 18:07:55 +02:00