2140 Commits

Author SHA1 Message Date
George Necula
e66e569947
Minor update to docsl trigger readthedocs (#2433) 2020-03-17 09:07:14 +01:00
George Necula
3362591c79 Updated CHANGELOG 2020-03-17 06:51:01 +01:00
Matthew Johnson
ae921c7a4a update changelog 2020-03-15 11:15:13 -07:00
Matthew Johnson
7f0463e2c9
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }

The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!

Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }

That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!

That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>

Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
Peter Hawkins
cf41f7682f
Add np.linalg and np.fft functions to documentation. (#2407) 2020-03-12 15:05:59 -04:00
George Necula
61b430eeb4
Added more documentation for how to fix notebook build failures (#2404) 2020-03-12 10:59:30 +01:00
Matthew Johnson
cdf188af2f
add raises-exception notebook cell metadata (#2402) 2020-03-11 09:42:25 -07:00
Jordan Hoffmann
ffa03403c9
Add jnp vs np out of bounds indexing to Sharp Bits nb (#2378) 2020-03-10 17:53:43 -07:00
us
8339511eb5
Implement NumPy sorting routines. (#2318)
Implement `np.msort`.
Related issue: #2079
2020-03-09 10:07:12 -04:00
Lucas Theis
05e5ccfdd5
Minor fix to docs which mentioned IOHW where it should be OIHW (#2381) 2020-03-09 06:08:32 -07:00
Skye Wanderman-Milne
a1fa6296cc
Document jax.device_put (#2366) 2020-03-05 14:45:01 -08:00
Peter Hawkins
eeffbca45e
Add profiler documentation. (#2365) 2020-03-05 16:07:17 -05:00
Peter Hawkins
64b1da9d48
Updates to jax that were deferred until after jaxlib 0.1.40 became th… (#2362)
* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version.
* Remove backward compatibility code.
* Use CustomCallWithLayout instead of CustomCall.

* Mention jaxlib version bump in changelog.
2020-03-05 13:10:20 -05:00
Peter Hawkins
ddd76b803b
Add a jax.profiler module that exposes the new TensorFlow profiler in… (#2354)
* Add a jax.profiler module that exposes the new TensorFlow profiler integration.

* Fix documentation.
2020-03-05 12:07:57 -05:00
Peter Hawkins
73443d28d8
Update release notes for jaxlib 0.1.40. (#2359) 2020-03-05 08:41:07 -05:00
Stephan Hoyer
61a76014fd
Recommend using pip for jaxlib more strongly in the dev guide (#2333)
Right now our developer guide suggests that it "may" work but it actually is
almost always the recommended choice.
2020-02-28 10:40:18 -08:00
George Necula
b21e7530f3
Added the optix.py documentation to RTD (#2312)
Issue: #2297
2020-02-26 21:30:21 +01:00
joao guilherme
f6e1d01f94
JIT differentiate -> JIT compile (#2279) 2020-02-23 22:04:02 +01:00
George Necula
89514f9278
Moved CHANGELOG to docs (#2252)
* Moved CHANGELOG to docs

This puts the documentation also on RTD, with TOC.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.

* Actually add the CHANGELOG.rst

* Added reminder comments to the CHANGELOG.rst
2020-02-23 19:18:06 +01:00
Stephan Hoyer
48f2a41453
Minor fixes to docs related to jax.numpy.vectorize (#2278)
- Show `numpy.jax.vectorize` explicitly in the JAX docs, rather than the
  original `numpy.vectorize.
- Updated regex for identifying function signatures in NumPy. This now correctly
  parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
  narrative content in the docstring but it should go somewhere else.
2020-02-23 19:10:39 +01:00
Matthew Johnson
96b66ac976
fix typo in autodiff cookbook 2020-02-19 12:37:59 -08:00
Peter Hawkins
b6e8341176
Improve developer documentation. (#2247)
Add Python version test to build.py.
2020-02-17 11:24:03 -08:00
George Necula
fcd949b695
Added blank line to autodiff cookbook to trigger an enumeration 2020-02-17 16:01:10 +01:00
Mathis Gerdes
3a0690fa11 Correct sign mistake in complex autodiff docs. 2020-02-17 14:28:56 +01:00
George Necula
42bf313fa1 Fixed the name of the excluded notebook
Issue: #2236
2020-02-15 12:10:30 +01:00
George Necula
1c6cd25417 Temporarily disable XLA_in_Python notebook, pending fixing of bug
Issue: #2236
2020-02-15 12:03:40 +01:00
George Necula
370558def3 Removed a couple of slow notebooks from RTD auto-rendering.
Trying to address the timeouts in RTD rendering.

Also fixed bad itemized list in autodiff cookbook, and a few minor warnings:
Issue: #2092
2020-02-15 11:43:10 +01:00
George Necula
938336e08a
Merge pull request #2216 from gnecula/documentation
Added the first draft of the Jaxpr documentation.
2020-02-14 07:23:47 +01:00
George Necula
20dbc62277 Updated docstrings based on review comments 2020-02-13 09:28:01 +01:00
Stephan Hoyer
00140f07e2
Add jax.numpy.vectorize (#2146)
* Add jax.numpy.vectorize

This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
  `static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).

Eventually we might want want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).

* Add deprecation warning to jax.experimental.vectorize

* improve implementation
2020-02-12 14:09:37 -08:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
Anselm Levskaya
28e802c6f1
Fix Gotchas notebook regarding control flow differentiation. (#2194) 2020-02-10 16:39:27 -08:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
George Necula
a955fd9dee Updated notebook that refered to freevars 2020-02-03 19:57:08 +01:00
Peter Hawkins
3c9ae5e221
Add jax.scipy.stats.logistic to documentation. (#2149) 2020-02-03 12:44:57 -05:00
Colin
d6489103f7
Bump cell execution timeout (#2147)
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to 

- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900seconds to build, total -- most of the time for jax is in executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).
2020-02-03 10:15:19 -05:00
Peter Hawkins
91cd20b173
Update documentation and changelog to mention DLPack and array interface support. (#2134) 2020-01-31 11:15:04 -05:00
Peter Hawkins
4803a75c3b
Implement np.block. (#2106)
Rename np.removechars to _removechars; it should never have been public.
2020-01-29 11:55:53 -05:00
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
Peter Hawkins
cfef568dd6
Implement jax.scipy.linalg.block_diag. (#2113) 2020-01-29 11:24:40 -05:00
Daniel Johnson
b68d8b5c4f Clarify instructions for building from source. (#2093)
Adds additional subsections of the `Building from source` documentation
page to make it more obvious that you can install `jaxlib` from pip
when doing Python-only development.
2020-01-28 12:48:37 -08:00
Ziyad Edher
0fca476c54 Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank

* Test np.linalg.matrix_rank

* Use helper numpy testing function

* Fix issue with 1D matrix rank procedure

* Add new tests for 1D matrices and jit

* Do not check dtypes to circumvent int32 vs int64

* Include documentation for matrix_rank

* Fix ordering

* Use np.sum
2020-01-26 11:29:33 -08:00
Ziyad Edher
0c95c26e97 Implement np.linalg.matrix_power (#2042)
* Implement numpy.linalg.matrix_power

* Write tests for numpy.linalg.matrix_power

* Check for input matrix shapes

* Move to matrix-multiplication operator in matrix power

* Improve error messages and directly use broadcasting

* Include matrix_power in documentation
2020-01-24 13:52:40 -08:00
Matthew Johnson
6b5ef898dc
fix autodiff cookbook np.allclose tuple bug (#2055) 2020-01-23 10:21:55 -08:00
Sri Hari Krishna Narayanan
03b2ae6d59 Issue1635 expm (#1940)
* Issue1635 expm
Implemented expm using Pade approximation. The implmentation is
wrapped using custom_transforms. Frechet derivatives are provided
using defvjp.

* Issue1635 expm

Implemented expm using Pade approximation based on tf.linalg.expm.

* Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635""

This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing
changes made to b63c190c7671ebb9b911a52dcc203285c56a8051.

* Issue1635 expm testing

Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm

* travis build Issue1635 branch

* Issue1635 expm testing

Use rand_small to get numerical agreeming

* Issue1635 expm testing
Use @jit to prevent recompilation

* Issue1635 expm testing

Use rand_small to get numerical agreement

* Revert "travis build Issue1635 branch"

This reverts commit 6139772555e3af79dc0307fce88838a480e42d38.

* Issue1635

Replace construct with  jax.numpy.select

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Remove the note that sparsity is not exploited because JAX does not support sparsity.

* Issue1635 expm

Support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly.

* Issue1635

Rename argument, make it positional. Update documentation

Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>
2020-01-21 21:11:51 -08:00
Surya Bulusu
71323b5d02 changes loop_mjp(f, x, M) (#2013)
a minor change: we iterate over M and not S
2020-01-16 17:47:15 -08:00
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
Stephan Hoyer
a5b6e8abf3
Real valued FFTs (#1657)
* WIP: real valued fft functions

Note: The transpose rule is not correct yet (hence the failing tests).

* Fix transpose rules for rfft and irfft

* Typo fix

* fix test failures in x64 mode

* Add 1d/2d real fft functions, plus docs
2020-01-13 14:59:00 -08:00
archis
05f09fc935 added rfftfreq, tests, and documentation link. 2020-01-10 16:31:47 -08:00
archis
1e8c9384f0 added fftfreq, corresponding tests, and documentation links. 2020-01-06 22:56:00 -08:00