50 Commits

Author SHA1 Message Date
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
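A minimal usage sketch of the new reduction (assuming the scipy-mirroring `calc_q` keyword; the input matrix is illustrative, and per the commit this runs on CPU only):

```python
import jax.numpy as jnp
from jax.scipy.linalg import hessenberg

# Reduce a real matrix to upper Hessenberg form H with unitary Q: a = Q H Q^T
a = jnp.array([[4., 1., 2.],
               [2., 3., 1.],
               [1., 2., 5.]])
h, q = hessenberg(a, calc_q=True)
```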
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
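Since this change is intended to be a numerical no-op, a quick sketch comparing the new primitive against the closed-form sigmoid:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(-3., 3., 7)
y = lax.logistic(x)            # sigmoid as a single lax primitive
ref = 1. / (1. + jnp.exp(-x))  # the expit formula it lowers to for now
```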
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Jake VanderPlas
c66f5dda60 DOC: add missing linalg functionality to docs 2022-03-15 09:55:59 -07:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
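A small sketch of the relocated API (the input values are illustrative; on an array this tiny the approximate top-k should coincide with the exact result):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 9., 1., 7., 5.])
vals, idxs = lax.approx_max_k(x, k=2)  # approximate top-k values and their indices
```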
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Peter Hawkins
f3aa5fa92f Document lax.GatherScatterMode.
Recommend the .at[...] property in the docstrings for lax.scatter_ operators.

Add several missing lax.scatter_ operators to the index.
2021-11-22 15:43:02 -05:00
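The recommended `.at[...]` property in a nutshell (a minimal sketch; the shapes are illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[2].set(7.0)  # out-of-place scatter: returns an updated copy; x is unchanged
```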
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Jake VanderPlas
94169b96a8 DOC: add conv_dimension_numbers and ConvGeneralDilatedDimensionNumbers to docs 2021-10-19 17:18:15 -07:00
Peter Hawkins
278ff13b66 Improve implementation of cbrt() in JAX.
Lower to XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.

Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.

Add testing for `sqrt`/`rsqrt` for more types.

[XLA:Python] Add cbrt to XLA:Python bindings.

PiperOrigin-RevId: 386316949
2021-07-22 14:01:28 -07:00
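A quick sketch of the real-valued cube root (input values are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([8., 27., -8.])
y = lax.cbrt(x)  # real cube root; well-defined for negative inputs, unlike x ** (1/3)
```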
Roman Novak
bc84c9fe8f Add lax.conv_general_dilated_local 2021-05-13 12:20:35 -07:00
Jake VanderPlas
33fde77bb1 Add lax.reduce_precision() 2021-04-05 09:54:14 -07:00
Jake VanderPlas
749ad95514 DOC: add transformations doc to HTML & reorganize contents 2021-03-08 16:25:04 -08:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Jake VanderPlas
a0b12bba25 DOC: fix minor formatting issues 2021-01-20 14:38:19 -08:00
Benjamin Chetioui
9c56277878 Add "Argument classes" section to jax.lax.rst. 2020-12-01 20:30:09 +01:00
Peter Hawkins
2cf2c719f2 Add documentation to several functions in jax.lax.linalg. 2020-11-05 18:53:47 -05:00
Roman Novak
da0bff2fa8 Add lax.conv_general_dilated_patches 2020-10-20 22:58:53 -07:00
Peter Hawkins
d3db7bd4be Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan.

We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.

lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.

The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.

We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.

Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s

After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s

Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s

After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
2020-10-15 20:51:55 -04:00
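The unified scan described above can be sketched as follows (the 2-D input demonstrating the new `axis=` parameter is illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1., 6.)
s = lax.associative_scan(jnp.add, x)  # inclusive prefix sum, O(log n) depth

# axis selection, added in this change:
m = jnp.ones((2, 4))
s2 = lax.associative_scan(jnp.add, m, axis=1)
```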
Peter Hawkins
db43e21b1d Improve documentation for a number of lax functions. 2020-10-14 21:18:09 -04:00
Jake Vanderplas
e0ebb144f9
Add switch and associative_scan to lax docs (#3946) 2020-08-03 12:32:32 -07:00
Chase Roberts
2b7a39f92b
Add pshuffle to docs (#3742) 2020-07-14 09:05:45 -04:00
Peter Hawkins
141fabbbf5
Reimplement argmin/argmax using a single pass variadic reduction. (#3611) 2020-07-01 11:01:22 -04:00
Stephan Hoyer
cc8fbb7669
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.

This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.

I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.

I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)

* Remove unnecessary branch

* Add lax.squeeze primitive

* Changes per review

* Fix typing

* Move expand_dims into lax

* Update per review; add comments/documentation

* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
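The axis-identity-preserving primitives discussed above, in a minimal sketch (shapes are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.)
y = lax.expand_dims(x, (0, 2))             # shape (1, 3, 1); multiple axes at once
z = lax.squeeze(y, (0, 2))                 # inverse operation: back to shape (3,)
b = lax.broadcast_in_dim(x, (2, 3), (1,))  # explicit input-to-output axis mapping
```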
Anselm Levskaya
ca4e396e31
Merge pull request #2853 from levskaya/topkjvp
Add top_k jvp and batching rules and tests
2020-04-29 00:57:29 +10:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
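With the jvp rule added here, `top_k` becomes differentiable; a minimal sketch (input values are illustrative):

```python
import jax.numpy as jnp
from jax import grad, lax

x = jnp.array([1., 5., 3.])
vals, idxs = lax.top_k(x, 2)

# The gradient routes cotangents back to the selected elements:
g = grad(lambda v: lax.top_k(v, 2)[0].sum())(x)
```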
Jamie Townsend
75617be803
Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib)

fixes #2263

* Add popcount docs

* Add population_count to lax_reference

* Use int prng (since we're only testing uints)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 22:32:52 -07:00
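The new primitive counts set bits per element on unsigned integers; a minimal sketch (values are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0, 1, 255], dtype=jnp.uint32)
c = lax.population_count(x)  # per-element popcount
```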
James Bradbury
3f59783352
Add pmean to lax documentation (#2778) 2020-04-20 19:03:43 -07:00
Matthew Johnson
fcc1e76c5a add docstring / reference doc link for axis_index
fixes #2534
2020-03-29 13:56:26 -07:00
Trevor Cai
d11a9ab185
Expose jax.lax.all_gather (#2449)
* Expose jax.lax.all_gather

* add all_gather to RTD
2020-03-19 16:35:00 +01:00
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
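A quick sketch of the pair (arguments are illustrative); the two regularized incomplete gamma functions are complementary:

```python
from jax.scipy.special import gammainc, gammaincc

p = gammainc(1.0, 1.0)   # regularized lower incomplete gamma: P(1, 1) = 1 - e**-1
q = gammaincc(1.0, 1.0)  # regularized upper counterpart, so p + q == 1
```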
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the log of the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
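A small sketch of both additions (arguments are illustrative; for a = b = 1 the incomplete beta reduces to the uniform CDF):

```python
from jax.scipy.special import betainc, betaln

p = betainc(1.0, 1.0, 0.5)  # regularized incomplete beta; equals x here
b = betaln(1.0, 1.0)        # log of the complete beta function: log B(1, 1) = 0
```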
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
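A one-line sketch of `nextafter` (the float32 dtype choice is illustrative): it returns the next representable value after the first argument, in the direction of the second.

```python
import jax.numpy as jnp

x = jnp.nextafter(jnp.float32(1.0), jnp.float32(2.0))  # smallest float32 above 1.0
```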
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
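The decorator described above can be sketched as follows (the wrapped function is illustrative): residuals of `f` are not stored for the backward pass; they are recomputed from the saved inputs, and the gradient is unchanged.

```python
import jax
import jax.numpy as jnp

@jax.remat  # recompute intermediates during the backward pass instead of saving them
def f(x):
    return jnp.sin(jnp.sin(x))

g = jax.grad(f)(2.0)  # same value as without remat, at lower peak memory
```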
Stephan Hoyer
5bcbce744e
Support closures in all arguments of lax.custom_root (#1570)
* WIP: linear solvers

* Draft of lax.linear_solve

* Refactor pytree munging inside lax.root.

The primitive's implementation and JVP rules are now 100% pytree free.

* Fixup linear_solve

* Linearize multiple times in _root_jvp to avoid zeros

* fix deftraced

* add a symmetric argument

* Fixup float64; add a test for symmetric/non-symmetric

* test zeros in linear_solve_jvp

* Revisions per review

* Adjust signature of linear_solve

* restore botched test

* variable names

* WIP: root solve jaxpr

* WIP more tests

* rewrite root

* Root works with jaxprs

* root -> custom_root

* WIP undefined tangent

* Delayed undefined JVP errors

* use raise_on_undefined_tangents inside define_implicit_gradient

* more tests on jvps with undefined tangents

* Remove define_implicit_gradient

* Support closures in custom_root

* revert api-test

* another test

* jit tests

* spelling
2019-10-29 16:00:00 -07:00
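A minimal sketch of `custom_root` with a closure, which this change enables (the implicit-sqrt construction and the closed-form "solver" are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

def implicit_sqrt(x):
    # Differentiable sqrt defined implicitly as the root of f(y) = y**2 - x.
    # Note that f closes over x, which custom_root now supports.
    f = lambda y: y ** 2 - x
    solve = lambda f, y0: jnp.sqrt(x)        # forward solver (closed form here)
    tangent_solve = lambda g, y: y / g(1.0)  # solve the scalar linear tangent system
    return lax.custom_root(f, 1.0, solve, tangent_solve)

val = implicit_sqrt(2.0)
grad_val = jax.grad(implicit_sqrt)(2.0)  # implicit differentiation: 1 / (2 * sqrt(2))
```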
Peter Hawkins
c485a3cc50
Remove stale reference to lapax.py. (#1546)
Add some missing documentation references.
2019-10-21 13:47:36 -04:00
Peter Hawkins
78132c150d Document all_to_all and ppermute. 2019-10-10 15:19:17 -04:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Stephan Hoyer
8c628a267b Implement lax.map
Fixes GH-1113
2019-08-05 12:14:05 -07:00
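A one-line sketch of `lax.map` (the mapped function is illustrative): it applies a function sequentially over the leading axis, trading the parallelism of `vmap` for lower memory use.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(4.)
ys = lax.map(lambda x: x ** 2, xs)  # scan-based map over the leading axis
```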
Peter Hawkins
4eb1820ae2 Add documentation to JAX modules. 2019-07-21 15:55:47 -04:00
Matthew Johnson
2f645dd36d remove one lax parallel op from sphinx docs 2019-05-17 13:24:06 -07:00
Matthew Johnson
9b1af47a59 improve documentation of lax parallel operators 2019-05-17 12:27:09 -07:00
Skye Wanderman-Milne
5d1c014509 Initial FFT support.
This change creates a new fft primitive in lax, and uses it to implement numpy's np.fft.fftn function.

Not-yet-implemented functionality:
- vmap
- 's' argument of fftn
- other numpy np.fft functions

Resolves #505.
2019-05-16 14:37:30 -07:00
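A quick sketch of the new functionality (the all-ones input is illustrative): the DC component of an FFT over a constant array equals the element count.

```python
import jax.numpy as jnp

x = jnp.ones((4, 4))
y = jnp.fft.fftn(x)  # n-dimensional FFT backed by the new lax fft primitive
```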
Peter Hawkins
c2369ce340 Add lax.scan to docs. 2019-05-13 14:46:35 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Peter Hawkins
86d8915c3d Add Sphinx-generated reference documentation for JAX. 2019-01-16 09:13:31 -05:00