Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition acquire it in their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information the jaxpr was specialized on
when traced out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
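A quick sketch of the effect (the exact printed form of a jaxpr varies
across versions):

  import jax
  import jax.numpy as jnp

  # With avals recorded in the jaxpr, the input shape f32[3] appears on
  # the binder itself, so reduce_sum no longer needs an input_shape
  # parameter for transposition.
  print(jax.make_jaxpr(lambda x: jnp.sum(x, axis=0))(jnp.ones(3)))
  # { lambda ; a:f32[3]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }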
Co-authored-by: Roy Frostig <frostig@google.com>
* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version.
* Remove backward compatibility code.
* Use CustomCallWithLayout instead of CustomCall.
* Mention jaxlib version bump in changelog.
* Moved CHANGELOG to docs
This puts the documentation on RTD as well, with a TOC.
Also changed its format to .rst for consistency.
Added GitHub links to the change log.
* Actually add the CHANGELOG.rst
* Added reminder comments to the CHANGELOG.rst
- Show `jax.numpy.vectorize` explicitly in the JAX docs, rather than the
original `numpy.vectorize`.
- Updated regex for identifying function signatures in NumPy. This now correctly
parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
narrative content in the docstring, but it should go somewhere else.
* Add jax.numpy.vectorize
This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
`static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).
Eventually we might want to consolidate the specification of signatures
with the signatures used by the shape-checking machinery, but it's nice
to emulate NumPy's existing interface, and this is already useful (e.g.,
for writing vectorized linear algebra routines).
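A minimal usage sketch of the new function, showing the gufunc-style
signature and the `excluded` behavior described above:

  import jax.numpy as jnp

  # Vectorize matrix-vector multiplication over leading batch dimensions
  # with a gufunc-style signature.
  matvec = jnp.vectorize(jnp.matmul, signature='(n,m),(m)->(n)')
  a = jnp.ones((4, 2, 3))    # a batch of four 2x3 matrices
  v = jnp.ones((4, 3))       # a batch of four 3-vectors
  print(matvec(a, v).shape)  # (4, 2)

  # `excluded` passes positional arguments through unvectorized, much
  # like static_argnums in jax.jit.
  def scale(x, factor):
      return x * factor
  scale_v = jnp.vectorize(scale, excluded={1})
  print(scale_v(jnp.arange(3.0), 2.0))  # [0. 2. 4.]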
* Add deprecation warning to jax.experimental.vectorize
* Improve implementation
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to
- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is spent executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).
Adds additional subsections to the `Building from source` documentation
page to make it more obvious that you can install `jaxlib` from pip
when doing Python-only development.
* Implement np.linalg.matrix_rank
* Test np.linalg.matrix_rank
* Use helper numpy testing function
* Fix issue with 1D matrix rank procedure
* Add new tests for 1D matrices and jit
* Do not check dtypes, to sidestep int32 vs. int64 mismatches
* Include documentation for matrix_rank
* Fix ordering
* Use np.sum
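A usage sketch of the resulting function:

  import jax.numpy as jnp

  # matrix_rank counts singular values above a tolerance, so a rank-1
  # outer product reports rank 1 even as a 3x3 matrix.
  a = jnp.outer(jnp.arange(1.0, 4.0), jnp.arange(1.0, 4.0))
  print(jnp.linalg.matrix_rank(a))  # 1

  # 1D inputs are the special case fixed above: a nonzero vector has
  # rank 1.
  print(jnp.linalg.matrix_rank(jnp.ones(5)))  # 1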
* Implement numpy.linalg.matrix_power
* Write tests for numpy.linalg.matrix_power
* Check for input matrix shapes
* Move to matrix-multiplication operator in matrix power
* Improve error messages and directly use broadcasting
* Include matrix_power in documentation
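A usage sketch, exercising the matrix-multiplication operator mentioned
above:

  import jax.numpy as jnp

  # matrix_power raises a square matrix to an integer power by repeated
  # matrix multiplication (the @ operator).
  m = jnp.array([[2.0, 0.0], [0.0, 3.0]])
  print(jnp.linalg.matrix_power(m, 3))  # [[8. 0.] [0. 27.]]
  print(jnp.linalg.matrix_power(m, 0))  # the 2x2 identity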
* Issue1635 expm
Implemented expm using a Padé approximation. The implementation is
wrapped using custom_transforms. Fréchet derivatives are provided
using defvjp.
* Issue1635 expm
Implemented expm using a Padé approximation, based on tf.linalg.expm.
* Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635""
This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing
changes made to b63c190c7671ebb9b911a52dcc203285c56a8051.
* Issue1635 expm testing
Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm
* travis build Issue1635 branch
* Issue1635 expm testing
Use rand_small to get numerical agreement
* Issue1635 expm testing
Use @jit to prevent recompilation
* Revert "travis build Issue1635 branch"
This reverts commit 6139772555e3af79dc0307fce88838a480e42d38.
* Issue1635
Replace construct with jax.numpy.select
* Issue1635
Restructure to support the docstring from SciPy
* Issue1635
Remove the note that sparsity is not exploited because JAX does not support sparsity.
* Issue1635 expm
Add support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly.
* Issue1635
Rename the argument and make it positional. Update the documentation.
Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>
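A usage sketch of the resulting API (the `upper_triangular` flag is the
explicit option described above):

  import jax.numpy as jnp
  from jax.scipy.linalg import expm

  # For a diagonal matrix, the matrix exponential is the elementwise exp
  # of the diagonal.
  a = jnp.diag(jnp.array([1.0, 2.0]))
  print(expm(a))  # approximately [[e, 0], [0, e^2]]

  # The upper-triangular code path is opted into explicitly rather than
  # autodetected.
  t = jnp.array([[1.0, 2.0], [0.0, 3.0]])
  print(expm(t, upper_triangular=True))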
* WIP: real valued fft functions
Note: The transpose rule is not correct yet (hence the failing tests).
* Fix transpose rules for rfft and irfft
* Typo fix
* Fix test failures in x64 mode
* Add 1d/2d real fft functions, plus docs
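A brief sketch of the real-FFT functions added here:

  import jax.numpy as jnp

  # rfft of a length-n real signal returns n//2 + 1 complex
  # coefficients; irfft reconstructs the signal given the output length.
  x = jnp.linspace(0.0, 1.0, 8)
  freqs = jnp.fft.rfft(x)
  print(freqs.shape)  # (5,)
  x_back = jnp.fft.irfft(freqs, n=8)
  print(jnp.allclose(x, x_back, atol=1e-5))  # True

  # The 2d variants operate over the last two axes by default.
  img = jnp.ones((4, 6))
  print(jnp.fft.rfft2(img).shape)  # (4, 4)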