58 Commits

Author SHA1 Message Date
Jake Vanderplas
f9a3aed0b6
Implement numpy.linalg.multi_dot (#2726)
* Implement numpy.linalg.multi_dot

* Thread precision through multi_dot
2020-04-15 20:35:54 -04:00
Matthew Johnson
7a4c4d555c use custom_jvp for internal functions 2020-03-29 20:48:08 -07:00
Matthew Johnson
c3e3d4807e temporarily revert parts of #2026 pending bug fix 2020-03-25 20:19:49 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Stephan Hoyer
6cceb2c778
Faster gradient rules for {numpy,scipy}.linalg.solve (#2220)
Fixes GH1747

The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
implementations of gradient for `solve` which rely upon differentiating matrix
factorization.

In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.

In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:

    from functools import partial
    import jax.scipy as jsp
    from jax import lax
    import jax.numpy as np
    import numpy as onp
    import jax

    def loss(solve):
      def f(a, b):
        return solve(a, b).sum()
      return f

    rs = onp.random.RandomState(0)
    N = 500
    K = 1
    a = rs.randn(N, N)
    a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
    b = jax.device_put(rs.randn(N, K))

    # general matrix solve
    grad = jax.jit(jax.grad(loss(np.linalg.solve)))
    grad(a, b).block_until_ready()
    %timeit grad(a, b).block_until_ready()
    # N=500, K=1: 11.4 ms -> 3.63 ms

    # positive definite solve
    grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
    grad(a, b).block_until_ready()
    %timeit grad(a, b).block_until_ready()
    # N=500, K=1: 9.22 ms -> 2.83 ms
2020-02-18 17:41:38 -08:00
Stephan Hoyer
aca7bccefd
Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve (#2144)
* Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve

This single implementation supports broadcasting like NumPy in both the NumPy
and SciPy interfaces to LU solve, even though upstream only NumPy (not SciPy)
supports broadcasting.

This change is technically backwards incompatible in the SciPy wrapper, which
previously supported adding extra dimensions to the end of `b`, e.g.,
`b.shape == (8, 4, 2)` when `a.shape == (8, 8)`. There was a test case for this,
but the behavior isn't documented in either JAX or SciPy.

* fix recursive import

* Use np.vectorize instead of experimental.vectorize
2020-02-12 17:05:18 -08:00
Guillem Orellana Trullols
aa0ca27062
Implementation of np.linalg.tensorsolve. (#2119)
* Tensorsolve implementation

* Tests working for tensorsolve #1999

* Moved tensorsolve to third party directory
2020-02-09 14:35:09 -08:00
StephenHogg
a0e1804e43
Implementation of np.linalg.{cond, tensorinv} (#2125)
* add np.linalg.cond in a third_party module

* remove unnecessary type maps

* rename cond.py to linalg.py for consistency

* shift LICENSE to correct directory; formatting changes; completed testing

* Add implementation and testing for tensorinv

* fix tests for tensorinv to stop build stalling

* blank __init__.py; add extra testing for cond; remove index assignments

* jax.lax.cond is breaking on jax.numpy.linalg.norm

* fix control flow issues; update tests to use curried functions

* clean up imports; remove commented code

* remove control flow from tests; remove unneeded functions
2020-02-07 12:20:04 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Ziyad Edher
0fca476c54 Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank

* Test np.linalg.matrix_rank

* Use helper numpy testing function

* Fix issue with 1D matrix rank procedure

* Add new tests for 1D matrices and jit

* Do not check dtypes to circumvent int32 vs int64

* Include documentation for matrix_rank

* Fix ordering

* Use np.sum
2020-01-26 11:29:33 -08:00
Ziyad Edher
0c95c26e97 Implement np.linalg.matrix_power (#2042)
* Implement numpy.linalg.matrix_power

* Write tests for numpy.linalg.matrix_power

* Check for input matrix shapes

* Move to matrix-multiplication operator in matrix power

* Improve error messages and directly use broadcasting

* Include matrix_power in documentation
2020-01-24 13:52:40 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Tuan Nguyen
7f4b641c6d Additional doc for np.linalg.pinv (#1820)
* starter code

* Update scipy_stats_test.py

* Update __init__.py

* Update scipy_stats_test.py

* starter code for pinv

* fix transpose, add more test cases & complex dtype

* update test to latest format

* update default rcond

* Update linalg.py

* bigger test size

* Update linalg.py

* Update linalg_test.py

* fix float issue

* Update linalg.py

* smaller test cases

* Update linalg_test.py

* try not forcing float

* explicit cast

* try a different casting

* try another casting

* Update doc for pinv

* Update linalg.py
2019-12-09 09:56:26 -08:00
Tuan Nguyen
2316a29ae9 Implement np.linalg.pinv (#1656)
* starter code

* Update scipy_stats_test.py

* Update __init__.py

* Update scipy_stats_test.py

* starter code for pinv

* fix transpose, add more test cases & complex dtype

* update test to latest format

* update default rcond

* Update linalg.py

* bigger test size

* Update linalg.py

* Update linalg_test.py

* fix float issue

* Update linalg.py

* smaller test cases

* Update linalg_test.py

* try not forcing float

* explicit cast

* try a different casting

* try another casting
2019-12-03 11:15:39 -08:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of an np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
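
A minimal sketch of the promotion rule described above (illustrative only; `np` is `jax.numpy`, as in the example earlier in this log):

    import jax.numpy as np

    x = np.zeros(3, np.float32)
    # The Python scalar carries a weak type, so it defers to the array's
    # dtype and the result stays float32 rather than widening to float64.
    print((1. + x).dtype)
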
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Stephen Tu
39daf07de3 Add trivial implementations of eigvals/eigvalsh (#1604)
* Add trivial implementations of eigvals/eigvalsh

The implementations simply delegate to eig/eigh.

* Enable eigvalsh test on TPU/GPU
2019-10-30 22:29:56 -04:00
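
A minimal sketch of the delegation described in this entry (illustrative only, not the exact source):

    import jax.numpy as np

    def eigvals(a):
      # keep only the eigenvalues computed by eig
      return np.linalg.eig(a)[0]

    def eigvalsh(a):
      # likewise, delegate to eigh and drop the eigenvectors
      return np.linalg.eigh(a)[0]
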
David Pfau
7a347dede9
Added TODO to fix slogdet grad for complex types 2019-09-17 18:58:34 +01:00
David Pfau
3e5f7869ea
Fixed prefix for np.linalg.solve 2019-09-16 21:27:55 +01:00
David Pfau
36aaaf9f55
Merge branch 'master' into master 2019-09-16 20:54:04 +01:00
David Pfau
d30108d773
Added jit to slogdet
This resolves the prior merge conflict
2019-09-16 20:52:36 +01:00
David Pfau
0171eb7c68
Replace jvp for det with jvp for slogdet 2019-09-14 14:35:27 +01:00
David Pfau
c2750c1be9
Extended jvp for det to handle inputs with >2 dims 2019-09-14 14:30:45 +01:00
Peter Hawkins
a5f67d553d Fix incorrect slogdet parity calculation in presence of batch dimensions. 2019-09-11 08:19:26 -04:00
David Pfau
b819bff1d3
Imports for custom_transforms and defjvp
Can't really do this if we don't have the right imports...
2019-09-05 15:22:36 +01:00
David Pfau
1c264db3a3
Slightly cleaner implementation of jvp for det
Replaced np.dot/np.linalg.inv with a single np.linalg.solve
2019-08-31 03:18:24 +01:00
David Pfau
ae2c39d081
Added custom jvp for np.linalg.det
For faster gradient computation for determinants, added a closed-form expression for the Jacobian-vector product. Still needs a test.
2019-08-31 02:13:12 +01:00
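
The closed form referred to here (and rewritten with `np.linalg.solve` in the entry above) is Jacobi's formula; a rough sketch for a single square matrix, not the exact source:

    import jax.numpy as np

    def det_jvp_sketch(a, da):
      # Jacobi's formula: d det(A) = det(A) * trace(A^{-1} dA),
      # using a linear solve instead of an explicit inverse.
      return np.linalg.det(a) * np.trace(np.linalg.solve(a, da))
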
Peter Hawkins
5ac356d680 Add support for batched triangular solve and LU decomposition on GPU using cuBLAS. 2019-08-08 13:34:53 -04:00
Peter Hawkins
5336722ad7 Support rank > 2 inputs to np.linalg.norm if axis==None and ord==None.
Matches an undocumented NumPy behavior:
https://github.com/numpy/numpy/issues/14215
2019-08-07 09:21:07 -04:00
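
A small illustration of the behavior (assuming `np` is `jax.numpy`):

    import jax.numpy as np

    x = np.ones((2, 3, 4))
    # With ord=None and axis=None, the input is flattened and the 2-norm
    # of the flattened array is returned, matching NumPy.
    print(np.linalg.norm(x))
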
Peter Hawkins
efdcef88c7 Remove experimental warning from linalg routines.
There's no particular reason to scare people with the experimental warning any longer; we don't know of any bugs here.
2019-07-23 21:13:14 -04:00
Peter Hawkins
12e622bbc1 Implement np.linalg.inv in terms of np.linalg.solve.
i.e. use an LU decomposition instead of a QR decomposition, now that an LU decomposition is available on all platforms.
2019-06-28 15:49:38 -04:00
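
Roughly, the change amounts to the following sketch (illustrative only, not the exact source):

    import jax.numpy as np

    def inv_via_solve(a):
      # A^{-1} is the solution X of A X = I, so inv reduces to solve,
      # which is backed by an LU decomposition.
      return np.linalg.solve(a, np.eye(a.shape[-1], dtype=a.dtype))
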
Peter Hawkins
a396276e78 Add unit_diagonal option to lax_linalg.solve_triangular.
LAPACK and cuBLAS both support treating the diagonal of a triangular matrix as 1 and ignoring the actual matrix contents. Plumb this ability through to lax.
2019-06-25 15:27:37 -04:00
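
A usage sketch via the SciPy wrapper, assuming the option is exposed there as `unit_diagonal` (mirroring SciPy's own signature):

    import jax.numpy as np
    import jax.scipy as jsp

    a = np.array([[2., 0.], [3., 4.]])
    b = np.array([1., 5.])
    # Treat the diagonal of `a` as ones, ignoring its stored values.
    x = jsp.linalg.solve_triangular(a, b, lower=True, unit_diagonal=True)
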
Peter Hawkins
f53ede7e5f Add batching support to numpy.linalg.solve. 2019-06-17 20:52:43 -04:00
Peter Hawkins
a96944eb53 Implement np.linalg.eig on CPU.
Fixes #639.
2019-05-13 15:59:58 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Anselm Levskaya
8cd3f448d5 fix missing symmetrize_input arg 2019-02-14 00:40:42 -08:00
Anselm Levskaya
8a84ae8d2a added jvp rule for eigh, tests 2019-02-13 21:50:39 -08:00
Matthew Johnson
78fd9e1a10 debug cholesky grad, remove stale dot_general check 2019-02-13 09:18:28 -08:00
Peter Hawkins
55acfb15e6 Implement np.linalg.norm. 2019-02-07 10:51:55 -05:00
Peter Hawkins
652f0df017 Fix some TODOs in linalg: use gather instead of matmul to permute matrix rows. 2019-02-04 21:48:03 -05:00
vishwakftw
e16f31e94f Merge branch 'master' of https://github.com/google/jax into svd 2019-01-08 09:30:05 +05:30
vishwakftw
954b047bea Address review comments 2019-01-08 09:24:48 +05:30
Peter Hawkins
1c2cff15d2 Finish implementation of symmetric eigendecomposition on CPU:
* add test case.
* add double and complex64 implementations.

Also add logic to all linalg methods to both coerce arguments to arrays and promote them to an inexact (float or complex) type if the argument is not already inexact.
2019-01-07 19:20:42 -05:00
vishwakftw
484db1e15f Add SVD for float and double types 2019-01-05 11:13:08 +05:30
Peter Hawkins
06135fa6f5 Implement numpy.linalg.solve and scipy.linalg.solve.
Make Cholesky and TriangularSolve work for complex numbers on CPU. The HLO implementations are broken for complex numbers on GPU/TPU, so no tests are enabled for these yet.
2018-12-21 16:29:45 -05:00
Peter Hawkins
a4386457e2 Fix test failures due to type mismatches in linear algebra tests.
Minor code cleanups.
2018-12-21 15:18:34 -05:00
Peter Hawkins
b68c93d37f Implement np.linalg.slogdet.
Change the implementation of np.linalg.det to call np.linalg.slogdet.

Add support for complex64 LU decomposition.
2018-12-20 22:18:20 -05:00
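
The relationship this relies on, as a brief sketch (illustrative only):

    import jax.numpy as np

    def det_via_slogdet(a):
      # slogdet returns (sign, log|det|); recombining the two recovers det
      # while keeping the intermediate computation in log space.
      sign, logabsdet = np.linalg.slogdet(a)
      return sign * np.exp(logabsdet)
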
Peter Hawkins
dfdc2e3806 Add LU decomposition implementation backed by LAPACK on the CPU platform.
Implement np.linalg.det, and scipy.linalg.{lu,lu_factor,det}.

Add missing abstractification to loop arguments.
Implement XLA abstractification rules for AbstractTuple, ConcreteArray, and ShapedArray.
2018-12-20 18:45:34 -05:00
Peter Hawkins
ab1ebc6bad Add experimental warning to numpy.linalg and scipy.linalg. 2018-12-17 17:39:46 -05:00