2127 Commits

Author SHA1 Message Date
Peter Hawkins
ddd76b803b
Add a jax.profiler module that exposes the new TensorFlow profiler in… (#2354)
* Add a jax.profiler module that exposes the new TensorFlow profiler integration.

* Fix documentation.
2020-03-05 12:07:57 -05:00
Peter Hawkins
73443d28d8
Update release notes for jaxlib 0.1.40. (#2359) 2020-03-05 08:41:07 -05:00
Stephan Hoyer
61a76014fd
Recommend using pip for jaxlib more strongly in the dev guide (#2333)
Right now our developer guide suggests that it "may" work, but it is actually
almost always the recommended choice.
2020-02-28 10:40:18 -08:00
George Necula
b21e7530f3
Added the optix.py documentation to RTD (#2312)
Issue: #2297
2020-02-26 21:30:21 +01:00
joao guilherme
f6e1d01f94
JIT differentiate -> JIT compile (#2279) 2020-02-23 22:04:02 +01:00
George Necula
89514f9278
Moved CHANGELOG to docs (#2252)
* Moved CHANGELOG to docs

This puts the documentation also on RTD, with TOC.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.

* Actually add the CHANGELOG.rst

* Added reminder comments to the CHANGELOG.rst
2020-02-23 19:18:06 +01:00
Stephan Hoyer
48f2a41453
Minor fixes to docs related to jax.numpy.vectorize (#2278)
- Show `jax.numpy.vectorize` explicitly in the JAX docs, rather than the
  original `numpy.vectorize`.
- Updated regex for identifying function signatures in NumPy. This now correctly
  parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
  narrative content in the docstring but it should go somewhere else.
2020-02-23 19:10:39 +01:00
Matthew Johnson
96b66ac976
fix typo in autodiff cookbook 2020-02-19 12:37:59 -08:00
Peter Hawkins
b6e8341176
Improve developer documentation. (#2247)
Add Python version test to build.py.
2020-02-17 11:24:03 -08:00
George Necula
fcd949b695
Added blank line to autodiff cookbook to trigger an enumeration 2020-02-17 16:01:10 +01:00
Mathis Gerdes
3a0690fa11 Correct sign mistake in complex autodiff docs. 2020-02-17 14:28:56 +01:00
George Necula
42bf313fa1 Fixed the name of the excluded notebook
Issue: #2236
2020-02-15 12:10:30 +01:00
George Necula
1c6cd25417 Temporarily disable XLA_in_Python notebook, pending fixing of bug
Issue: #2236
2020-02-15 12:03:40 +01:00
George Necula
370558def3 Removed a couple of slow notebooks from RTD auto-rendering.
Trying to address the timeouts in RTD rendering.

Also fixed bad itemized list in autodiff cookbook, and a few minor warnings:
Issue: #2092
2020-02-15 11:43:10 +01:00
George Necula
938336e08a
Merge pull request #2216 from gnecula/documentation
Added the first draft of the Jaxpr documentation.
2020-02-14 07:23:47 +01:00
George Necula
20dbc62277 Updated docstrings based on review comments 2020-02-13 09:28:01 +01:00
Stephan Hoyer
00140f07e2
Add jax.numpy.vectorize (#2146)
* Add jax.numpy.vectorize

This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
  `static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).

Eventually we might want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).

* Add deprecation warning to jax.experimental.vectorize

* improve implementation
2020-02-12 14:09:37 -08:00
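A minimal sketch (not from the commit itself) of the two features the message highlights: the `excluded` argument, which keeps a positional argument out of broadcasting much like `static_argnums` in `jax.jit`, and NumPy-style gufunc signatures:

```python
import jax.numpy as jnp

# `excluded={1}` passes argument 1 through unvectorized.
def scale(x, factor):
    return x * factor

vscale = jnp.vectorize(scale, excluded={1})
scaled = vscale(jnp.arange(3.0), 2.0)  # [0., 2., 4.]

# A gufunc-style signature vectorizes jnp.dot over leading batch dims.
batched_dot = jnp.vectorize(jnp.dot, signature='(n),(n)->()')
out = batched_dot(jnp.ones((4, 3)), jnp.ones((4, 3)))  # shape (4,)
```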
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
Anselm Levskaya
28e802c6f1
Fix Gotchas notebook regarding control flow differentiation. (#2194) 2020-02-10 16:39:27 -08:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
George Necula
a955fd9dee Updated notebook that referred to freevars 2020-02-03 19:57:08 +01:00
Peter Hawkins
3c9ae5e221
Add jax.scipy.stats.logistic to documentation. (#2149) 2020-02-03 12:44:57 -05:00
Colin
d6489103f7
Bump cell execution timeout (#2147)
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to 

- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is in executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).
2020-02-03 10:15:19 -05:00
Peter Hawkins
91cd20b173
Update documentation and changelog to mention DLPack and array interface support. (#2134) 2020-01-31 11:15:04 -05:00
Peter Hawkins
4803a75c3b
Implement np.block. (#2106)
Rename np.removechars to _removechars; it should never have been public.
2020-01-29 11:55:53 -05:00
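An illustrative use of the new `np.block` (not taken from the commit): assembling a matrix from sub-blocks given as a nested list:

```python
import jax.numpy as jnp

A = 2.0 * jnp.eye(2)
Z = jnp.zeros((2, 2))
# Nested lists of blocks are concatenated into one (4, 4) matrix.
M = jnp.block([[A, Z],
               [Z, A]])
```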
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
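A quick sanity check (an illustration, not from the commit): the regularized lower and upper incomplete gamma functions are complementary, so they should sum to one:

```python
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc

a = jnp.float32(2.5)
x = jnp.float32(1.3)
# P(a, x) + Q(a, x) == 1 for the regularized incomplete gamma functions.
total = gammainc(a, x) + gammaincc(a, x)
```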
Peter Hawkins
cfef568dd6
Implement jax.scipy.linalg.block_diag. (#2113) 2020-01-29 11:24:40 -05:00
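A small sketch of the new function (illustrative only): `block_diag` stacks its argument matrices along the diagonal, zero-filling everywhere else:

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

# A (2, 2) block and a (1, 1) block combine into a (3, 3) matrix.
D = block_diag(jnp.ones((2, 2)), 3.0 * jnp.eye(1))
```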
Daniel Johnson
b68d8b5c4f Clarify instructions for building from source. (#2093)
Adds additional subsections of the `Building from source` documentation
page to make it more obvious that you can install `jaxlib` from pip
when doing Python-only development.
2020-01-28 12:48:37 -08:00
Ziyad Edher
0fca476c54 Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank

* Test np.linalg.matrix_rank

* Use helper numpy testing function

* Fix issue with 1D matrix rank procedure

* Add new tests for 1D matrices and jit

* Do not check dtypes to circumvent int32 vs int64

* Include documentation for matrix_rank

* Fix ordering

* Use np.sum
2020-01-26 11:29:33 -08:00
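An illustrative check of the new `np.linalg.matrix_rank` (not from the commit): a matrix whose second row is a multiple of the first is rank-deficient:

```python
import jax.numpy as jnp

# Row 2 = 2 * row 1, so the rank is 1, not 2.
A = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])
rank_deficient = jnp.linalg.matrix_rank(A)
full_rank = jnp.linalg.matrix_rank(jnp.eye(3))
```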
Ziyad Edher
0c95c26e97 Implement np.linalg.matrix_power (#2042)
* Implement numpy.linalg.matrix_power

* Write tests for numpy.linalg.matrix_power

* Check for input matrix shapes

* Move to matrix-multiplication operator in matrix power

* Improve error messages and directly use broadcasting

* Include matrix_power in documentation
2020-01-24 13:52:40 -08:00
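An illustrative use of the new `np.linalg.matrix_power` (not from the commit): repeated matrix multiplication, with the `n == 0` case returning the identity:

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0],
               [0.0, 3.0]])
P = jnp.linalg.matrix_power(A, 3)  # diagonal entries cube: diag(8, 27)
I = jnp.linalg.matrix_power(A, 0)  # zeroth power is the identity
```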
Matthew Johnson
6b5ef898dc
fix autodiff cookbook np.allclose tuple bug (#2055) 2020-01-23 10:21:55 -08:00
Sri Hari Krishna Narayanan
03b2ae6d59 Issue1635 expm (#1940)
* Issue1635 expm
Implemented expm using a Pade approximation. The implementation is
wrapped using custom_transforms. Frechet derivatives are provided
using defvjp.

* Issue1635 expm

Implemented expm using Pade approximation based on tf.linalg.expm.

* Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635""

This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing
changes made to b63c190c7671ebb9b911a52dcc203285c56a8051.

* Issue1635 expm testing

Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm

* travis build Issue1635 branch

* Issue1635 expm testing

Use rand_small to get numerical agreement

* Issue1635 expm testing
Use @jit to prevent recompilation

* Issue1635 expm testing

Use rand_small to get numerical agreement

* Revert "travis build Issue1635 branch"

This reverts commit 6139772555e3af79dc0307fce88838a480e42d38.

* Issue1635

Replace construct with  jax.numpy.select

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Remove the note that sparsity is not exploited because JAX does not support sparsity.

* Issue1635 expm

Support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly.

* Issue1635

Rename argument, make it positional. Update documentation

Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>
2020-01-21 21:11:51 -08:00
Surya Bulusu
71323b5d02 changes loop_mjp(f, x, M) (#2013)
A minor change: we iterate over M, not S.
2020-01-16 17:47:15 -08:00
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
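An illustrative check of the `betainc` named in the title (not from the commit): for `a == b == 1` the regularized incomplete beta function reduces to the identity, `I_x(1, 1) == x`:

```python
import jax.numpy as jnp
from jax.scipy.special import betainc

x = jnp.float32(0.25)
# I_x(1, 1) == x for the regularized incomplete beta function.
val = betainc(jnp.float32(1.0), jnp.float32(1.0), x)
```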
Stephan Hoyer
a5b6e8abf3
Real valued FFTs (#1657)
* WIP: real valued fft functions

Note: The transpose rule is not correct yet (hence the failing tests).

* Fix transpose rules for rfft and irfft

* Typo fix

* fix test failures in x64 mode

* Add 1d/2d real fft functions, plus docs
2020-01-13 14:59:00 -08:00
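An illustrative round-trip through the new real-valued FFT functions (not from the commit): for real input of length `n`, `rfft` returns only the `n // 2 + 1` non-redundant frequencies, and `irfft` inverts it:

```python
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 0.0, -1.0])
X = jnp.fft.rfft(x)        # 3 outputs for length-4 real input
y = jnp.fft.irfft(X, n=4)  # round-trips back to x
```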
archis
05f09fc935 added rfftfreq, tests, and documentation link. 2020-01-10 16:31:47 -08:00
archis
1e8c9384f0 added fftfreq, corresponding tests, and documentation links. 2020-01-06 22:56:00 -08:00
Archis Joglekar
ca15512932 added fft2 and ifft2, corresponding tests, and documentation links. (#1939) 2020-01-04 18:21:30 -08:00
Archis Joglekar
d9c6a5f4a8 fft and ifft implementation (#1926)
* first commit with placeholders for tests

* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis

added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds

* removed placeholders
added functions to .rst file
2020-01-02 17:35:22 -08:00
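The inverse property tested in the commit can be sketched as follows (illustrative, not the commit's own test code):

```python
import jax.numpy as jnp

x = jnp.arange(4.0)
X = jnp.fft.fft(x)
# ifft inverts fft; the imaginary part is ~0 for real input.
x_back = jnp.fft.ifft(X)
```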
John Mellor
fd6067471e Fix minor typo in Common_Gotchas_in_JAX.ipynb
Moved misplaced backtick
2020-01-02 07:43:59 -08:00
Yo
322ebe7c9b Update docs/conf.py 2019-12-30 11:27:12 -08:00
flowed
e0693fe649 Fix Typos 2019-12-30 11:27:12 -08:00
Matthew Johnson
f5723848d3
fix error in autodiff cookbook: 3x not 2x 2019-12-30 07:36:36 -08:00
David Bieber
30bede1f6a fix typo in autodiff cookbook (#1921) 2019-12-27 11:02:06 -08:00
Peter Hawkins
698babf9ec
Implement jax.numpy.nonzero and 1-argument jax.numpy.where. (#1905)
* Implement jax.numpy.nonzero.

* Implement the one-argument form of np.where.

* Fix output type and error message.

* Add lax_description strings to where and nonzero.
2019-12-20 18:42:33 -05:00
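A small sketch of the two additions (illustrative only): `nonzero` returns a tuple of index arrays, and the one-argument form of `where` behaves the same way:

```python
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])
(idx,) = jnp.nonzero(x)   # indices of the nonzero entries: [1, 3]
(same,) = jnp.where(x)    # one-argument where matches nonzero
```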
Matthew Johnson
8dad859e04 streamline readme, add pmap 2019-12-14 10:32:04 -08:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
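An illustrative use (not from the commit): `nextafter` returns the next representable float toward a target, here the smallest float32 strictly greater than 1.0:

```python
import jax.numpy as jnp

# Smallest representable float32 above 1.0 (1 + 2**-23).
up = jnp.nextafter(jnp.float32(1.0), jnp.float32(2.0))
```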
Peter Hawkins
e87d9718c3
Support IntEnum values as arguments to JAX functions. (#1840)
* Support IntEnum values as arguments to JAX functions.

When abstractifying a Python value, search the method-resolution order (MRO) of the type rather than only looking at the value's own type. IntEnum instances are subclasses of int, so this allows us to correctly handle them as integers, much as NumPy itself does.
2019-12-11 12:27:11 -05:00
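The MRO-based behavior described above can be sketched like this (an illustration with a hypothetical `Direction` enum, not code from the commit): an `IntEnum` is a subclass of `int`, so it is abstractified as an integer when passed to a jitted function:

```python
import enum
import jax
import jax.numpy as jnp

class Direction(enum.IntEnum):  # hypothetical example enum
    LEFT = -1
    RIGHT = 1

@jax.jit
def step(x, d):
    return x + d

# The IntEnum member is handled as a plain integer argument.
out = step(jnp.float32(0.0), Direction.RIGHT)
```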
tamaranorman
26e863923a Support atrous conv in same padded convolution and add a warning when using a transposed convolution with same or valid padding. (#1806)
PiperOrigin-RevId: 283517237
2019-12-09 08:06:59 -08:00
Peter Hawkins
d958f3007d
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.

NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.

This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)   
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)   
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)    
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)    
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00