* Moved CHANGELOG to docs
This also puts the changelog on RTD, with a table of contents.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.
* Actually add the CHANGELOG.rst
* Added reminder comments to the CHANGELOG.rst
- Show `jax.numpy.vectorize` explicitly in the JAX docs, rather than the
original `numpy.vectorize`.
- Updated regex for identifying function signatures in NumPy. This now correctly
parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
narrative content in the docstring but it should go somewhere else.
* Add jax.numpy.vectorize
This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
`static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).
Eventually we might want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).
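For illustration, a minimal sketch of the intended usage (the function names `mse` and `clip` below are made-up examples, not part of the change):
```
from functools import partial
import jax.numpy as jnp

# Map a core signature '(n),(n)->()' over any leading batch dimensions.
def mse(a, b):
    return jnp.mean((a - b) ** 2)

batched_mse = jnp.vectorize(mse, signature='(n),(n)->()')
print(batched_mse(jnp.ones((4, 3)), jnp.zeros((4, 3))).shape)  # (4,)

# `excluded` keeps a positional argument out of broadcasting,
# analogous to `static_argnums` in `jax.jit`.
@partial(jnp.vectorize, excluded={1})
def clip(x, max_value):
    return jnp.minimum(x, max_value)

print(clip(jnp.arange(5.0), 2.0))  # [0. 1. 2. 2. 2.]
```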
* Add deprecation warning to jax.experimental.vectorize
* Improve implementation
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to:
- Cell timeouts (which this tries to fix),
- Execution timeout (Read the Docs gives 900 seconds to build, total -- most of the build time for jax is spent executing the notebooks),
- Other somewhat random/inscrutable errors (and one of the timeouts could plausibly be what triggers an inscrutable error during execution).
Adds additional subsections to the `Building from source` documentation
page to make it more obvious that you can install `jaxlib` from pip
when doing Python-only development.
* Implement np.linalg.matrix_rank
* Test np.linalg.matrix_rank
* Use helper numpy testing function
* Fix issue with 1D matrix rank procedure
* Add new tests for 1D matrices and jit
* Do not check dtypes to circumvent int32 vs int64
* Include documentation for matrix_rank
* Fix ordering
* Use np.sum
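A minimal sketch of the SVD-based approach these commits describe (the default tolerance follows NumPy's documented formula; this is illustrative, not the verbatim implementation):
```
import jax.numpy as jnp

def matrix_rank(M, tol=None):
    # 1-D input: rank is 1 unless every entry is zero.
    if M.ndim == 1:
        return (M != 0).any().astype(jnp.int32)
    s = jnp.linalg.svd(M, compute_uv=False)
    if tol is None:
        # NumPy's default: largest singular value scaled by the
        # matrix size and the dtype's machine epsilon.
        tol = s.max() * max(M.shape) * jnp.finfo(s.dtype).eps
    return jnp.sum(s > tol)  # count singular values above tolerance
```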
* Implement numpy.linalg.matrix_power
* Write tests for numpy.linalg.matrix_power
* Check for input matrix shapes
* Move to matrix-multiplication operator in matrix power
* Improve error messages and directly use broadcasting
* Include matrix_power in documentation
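For illustration, a sketch of exponentiation by squaring using the matrix-multiplication operator, with the shape check the commits mention (not the verbatim implementation):
```
import jax.numpy as jnp

def matrix_power(a, n):
    if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
        raise ValueError(f"expected a square matrix, got shape {a.shape}")
    if n < 0:
        a, n = jnp.linalg.inv(a), -n
    result = jnp.eye(a.shape[-1], dtype=a.dtype)
    while n > 0:
        if n % 2:             # odd bit: fold the current power in
            result = result @ a
        a = a @ a             # square for the next bit
        n //= 2
    return result
```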
* Issue1635 expm
Implemented expm using a Pade approximation. The implementation is
wrapped using custom_transforms, and Frechet derivatives are provided
using defvjp.
* Issue1635 expm
Implemented expm using Pade approximation based on tf.linalg.expm.
* Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635""
This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing
changes made to b63c190c7671ebb9b911a52dcc203285c56a8051.
* Issue1635 expm testing
Add a test that compares the numerical output of scipy.linalg.expm against jax.scipy.linalg.expm.
* travis build Issue1635 branch
* Issue1635 expm testing
Use @jit to prevent recompilation
* Issue1635 expm testing
Use rand_small to get numerical agreement
* Revert "travis build Issue1635 branch"
This reverts commit 6139772555e3af79dc0307fce88838a480e42d38.
* Issue1635
Replace construct with jax.numpy.select
* Issue1635
Restructure to support the docstring from SciPy
* Issue1635
Remove the note that sparsity is not exploited, since JAX does not support sparse matrices anyway.
* Issue1635 expm
Add support for the case where A is upper triangular. Instead of autodetecting this, the option is specified explicitly.
* Issue1635
Rename the argument and make it positional. Update documentation.
Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>
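A small usage sketch (the rotation-matrix identity gives an easy correctness check; the gradient call assumes the defvjp rule from the commits above):
```
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# exp of a skew-symmetric generator is a rotation matrix.
A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
print(expm(A))  # ~[[cos 1, sin 1], [-sin 1, cos 1]]

# The Frechet-derivative VJP lets gradients flow through expm.
g = jax.grad(lambda a: expm(a).sum())(A)
print(g.shape)  # (2, 2)
```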
* WIP: real-valued fft functions
Note: The transpose rule is not correct yet (hence the failing tests).
* Fix transpose rules for rfft and irfft
* Typo fix
* fix test failures in x64 mode
* Add 1d/2d real fft functions, plus docs
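A quick sketch of the round-trip property the real-FFT pair should satisfy (illustrative only):
```
import jax.numpy as jnp

x = jnp.arange(8.0)
X = jnp.fft.rfft(x)           # n//2 + 1 complex coefficients for real input
x_back = jnp.fft.irfft(X)     # inverts exactly for even-length input
print(jnp.allclose(x, x_back))  # True
```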
* first commit with placeholders for tests
* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis
added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds
* removed placeholders
added functions to .rst file
* Implement jax.numpy.nonzero.
* Implement the one-argument form of np.where.
* Fix output type and error message.
* Add lax_description strings to where and nonzero.
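A brief sketch of the two entry points on concrete inputs (their output shapes depend on the data, so they can't be used under `jit`, which is why the docstrings carry extra notes):
```
import jax.numpy as jnp

x = jnp.array([0, 2, 0, 3])
print(jnp.nonzero(x))    # indices of the nonzero entries: ([1, 3],)
print(jnp.where(x > 2))  # one-argument form, equivalent to nonzero(x > 2)
```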
* Support IntEnum values as arguments to JAX functions.
When abstractifying a Python value, search the method-resolution order (MRO) of the type rather than only looking at the value's own type. IntEnum instances are subclasses of int, so this allows us to correctly handle them as integers, much as NumPy itself does.
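A minimal sketch of what this enables (the `Direction` enum is a made-up example):
```
import enum
import jax
import jax.numpy as jnp

class Direction(enum.IntEnum):
    LEFT = -1
    RIGHT = 1

# IntEnum members subclass int, so the MRO search abstractifies
# them as ordinary integers when tracing.
step = jax.jit(lambda d, x: x + d)
print(step(Direction.RIGHT, jnp.arange(3)))  # [1 2 3]
```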
* Change the JAX type promotion table to prefer inexact types during type promotion.
NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.
This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators. For example:
```
In [1]: import numpy as onp

In [2]: from jax import numpy as np

In [3]: onp.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float64')

In [4]: onp.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float64')

In [5]: np.promote_types(onp.float32, onp.int32)
Out[5]: dtype('float32')

In [6]: np.promote_types(onp.float16, onp.int64)
Out[6]: dtype('float16')
```
This change is in preparation for enabling x64 mode by default on all platforms.