76 Commits

Author SHA1 Message Date
Akihiro Nitta
06170da69a
Use raise from 2020-09-30 01:20:00 +09:00
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
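
A rough illustration of what the permutation representation means, assuming 0-based LAPACK-style swap indices as input (the helper name is hypothetical, not part of this change):

    import numpy as np

    def pivots_to_permutation(pivots, n):
      # Apply the row swaps recorded during partial pivoting to the identity
      # ordering; row i of the pivoted matrix is row perm[i] of the original.
      perm = np.arange(n)
      for i, p in enumerate(pivots):
        perm[i], perm[p] = perm[p], perm[i]
      return perm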
2020-09-10 11:16:35 -04:00
Peter Hawkins
e06a6ab6bf
Add support for negative axes to vmap. (#4111)
* Add support for negative axes to vmap.

* Add workaround for out-of-range vmap axes.
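
A small usage sketch of the new behavior (illustrative only):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((3, 5))
    # Map over the last axis of the input by passing a negative in_axes value,
    # mirroring NumPy's negative-axis convention.
    row_sums = jax.vmap(jnp.sum, in_axes=-1)(x)  # shape (5,)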
2020-08-24 20:21:19 -04:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding the instantiation of large
constants at trace time, and can cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
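
A hedged illustration of the trace-time-constants point (the function and shapes here are made up for illustration):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      # Without omnistaging this constant has no data dependence on x, so it is
      # materialized eagerly while tracing; with omnistaging enabled it is
      # staged out into the jitted computation instead.
      big = jnp.ones((4096, 4096))
      return x + big.sum()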
2020-07-30 12:59:36 -07:00
Jake Vanderplas
0a6b715cd4
Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes #3689) (#3698) 2020-07-09 16:31:08 -07:00
Jake Vanderplas
b813ae3aff
Cleanup: record names in get_module_functions (#3697) 2020-07-08 14:44:49 -07:00
Peter Hawkins
e680304dca
Remove warning suppression for tuple and list arguments to reductions. (#3545)
Fix callers.
2020-06-24 15:59:31 -04:00
Jake Vanderplas
c77c0838fe
deflake jax.numpy and add to flake8 check (#3312) 2020-06-03 14:18:48 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
George Necula
a2d6b1aab4
Fix typo in lstsq (#3052) 2020-05-12 09:06:22 +03:00
Jake Vanderplas
db71f3c5fc
Initial implementation of np.linalg.lstsq() via SVD (#2744) 2020-05-11 14:53:17 -07:00
Jake Vanderplas
5dfff9eaab
Cleanup: move _wraps into jax.numpy._utils. (#2987)
Why? This prevents circular imports within the numpy submodule.
2020-05-06 15:17:55 -07:00
Peter Hawkins
2f09e89e72
Update internal aliases to lax_numpy to jnp instead of np. (#2975) 2020-05-05 20:41:57 -04:00
David Pfau
ffaf417b1b
Fix typo in docstring for _cofactor_solve (#2844)
Found a small typo in the description of _cofactor_solve
2020-04-25 08:32:27 -07:00
David Pfau
02b3fc5a7d
Custom derivative for np.linalg.det (#2809)
* Add vjp and jvp rules for jnp.linalg.det

* Add tests for new determinant gradients

* Replace index_update with concatenate in cofactor_solve

This avoids issues with index_update not having a transpose rule, removing one bug in the way of automatically converting the JVP into a VJP (still need to deal with the np.where).

* Changes to cofactor_solve so it can be transposed

This allows a single JVP rule to give both forward and backward derivatives

* Update det grad tests

All tests pass now - however second derivatives still do not work for singular matrices.

* Add explanation to docstring for _cofactor_solve

* Fixed comment
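
For orientation, a much-simplified sketch of the determinant-derivative identity this rule builds on (not the cofactor-based implementation described above, which additionally handles singular inputs and transposition):

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def det(a):
      return jnp.linalg.det(a)

    @det.defjvp
    def det_jvp(primals, tangents):
      a, = primals
      da, = tangents
      d = jnp.linalg.det(a)
      # d(det A) = det(A) * trace(A^{-1} dA); using solve avoids an explicit
      # inverse but, unlike the cofactor formulation, breaks down for singular A.
      return d, d * jnp.trace(jnp.linalg.solve(a, da))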
2020-04-25 08:26:25 -07:00
Peter Hawkins
8fe3c59ced
Add explicit derivative for jax.numpy.linalg.pinv. (#2794)
* Add explicit derivative for jax.numpy.linalg.pinv.

* Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
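
For reference, the standard constant-rank differential of the pseudoinverse (Golub & Pereyra), which an explicit JVP of this kind typically follows for real A:

    dA^+ = -A^+ dA A^+ + A^+ A^{+T} dA^T (I - A A^+) + (I - A^+ A) dA^T A^{+T} A^+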
2020-04-22 20:15:04 -04:00
Stephan Hoyer
e6f0b8d87d
Raise an error if stop_gradient is called on non-arrays (#2750)
* Raise an error if stop_gradient is called on non-arrays

* Fix incorrect usage of stop_gradient in solve()

* fix *other* misuse of stop_gradient
2020-04-17 12:42:53 -07:00
Jake Vanderplas
f9a3aed0b6
Implement numpy.linalg.multi_dot (#2726)
* Implement numpy.linalg.multi_dot

* Thread precision through multi_dot
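
A small usage illustration; multi_dot picks the cheapest parenthesization of a chain of matrix products, and the threaded precision argument controls the underlying dots:

    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(0)
    A = random.normal(key, (10, 1000))
    B = random.normal(key, (1000, 5))
    C = random.normal(key, (5, 50))
    # Equivalent to A @ B @ C, but the association order is chosen to
    # minimize the number of scalar multiplications.
    out = jnp.linalg.multi_dot([A, B, C])  # shape (10, 50)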
2020-04-15 20:35:54 -04:00
Matthew Johnson
7a4c4d555c use custom_jvp for internal functions 2020-03-29 20:48:08 -07:00
Matthew Johnson
c3e3d4807e temporarily revert parts of #2026 pending bug fix 2020-03-25 20:19:49 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Stephan Hoyer
6cceb2c778
Faster gradient rules for {numpy,scipy}.linalg.solve (#2220)
Fixes GH1747

The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
implementations of gradient for `solve` which rely upon differentiating matrix
factorization.

In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.

In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:

    from functools import partial
    import jax.scipy as jsp
    from jax import lax
    import jax.numpy as np
    import numpy as onp
    import jax

    def loss(solve):
      def f(a, b):
        return solve(a, b).sum()
      return f

    rs = onp.random.RandomState(0)
    N = 500
    K = 1
    a = rs.randn(N, N)
    a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
    b = jax.device_put(rs.randn(N, K))

    # general matrix solve
    grad = jax.jit(jax.grad(loss(np.linalg.solve)))
    grad(a, b).block_until_ready()
    %timeit grad(a, b).block_until_ready()
    # N=500, K=1: 11.4 ms -> 3.63 ms

    # positive definite solve
    grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
    grad(a, b).block_until_ready()
    %timeit grad(a, b).block_until_ready()
    # N=500, K=1: 9.22 ms -> 2.83 ms
2020-02-18 17:41:38 -08:00
Stephan Hoyer
aca7bccefd
Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve (#2144)
* Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve

This single implementation supports broadcasting like NumPy in both the NumPy
and SciPy interfaces to LU solve, even though only original NumPy supports
broadcasting.

This change is technically backwards incompatible in the SciPy wrapper, which
previously supported adding extra dimensions to the end of `b`, e.g.,
`b.shape == (8, 4, 2)` when `a.shape == (8, 8)`. There was a testcase for this,
but it isn't documented in either JAX or SciPy.

* fix recursive import

* Use np.vectorize instead of experimental.vectorize
2020-02-12 17:05:18 -08:00
Guillem Orellana Trullols
aa0ca27062
Implementation of np.linalg.tensorsolve. (#2119)
* Tensorsolve implementation

* Tests working for tensorsolve #1999

* Moved tensorsolve to third party directory
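
A usage sketch mirroring the NumPy semantics (solve a x = b for x, where the trailing dimensions of a match the shape of x):

    import numpy as onp
    import jax.numpy as jnp

    a = jnp.asarray(onp.eye(2 * 3 * 4).reshape(2 * 3, 4, 2, 3, 4))
    b = jnp.asarray(onp.random.randn(2 * 3, 4))
    x = jnp.linalg.tensorsolve(a, b)
    x.shape  # (2, 3, 4)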
2020-02-09 14:35:09 -08:00
StephenHogg
a0e1804e43
Implementation of np.linalg.{cond, tensorinv} (#2125)
* add np.linalg.cond in a third_party module

* remove unnecessary type maps

* rename cond.py to linalg.py for consistency

* shift LICENSE to correct directory; formatting changes; completed testing

* Add implementation and testing for tensorinv

* fix tests for tensorinv to stop build stalling

* blank __init__.py; add extra testing for cond; remove index assignments

* jax.lax.cond is breaking on jax.numpy.linalg.norm

* fix control flow issues; update tests to use curried functions

* clean up imports; remove commented code

* remove control flow from tests; remove unneeded functions
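
For the default 2-norm, the condition number reduces to the ratio of extreme singular values; a rough sketch of that relationship (not the code added here):

    import jax.numpy as jnp

    def cond_2(a):
      # 2-norm condition number: largest singular value over smallest.
      s = jnp.linalg.svd(a, compute_uv=False)
      return s[0] / s[-1]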
2020-02-07 12:20:04 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Ziyad Edher
0fca476c54 Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank

* Test np.linalg.matrix_rank

* Use helper numpy testing function

* Fix issue with 1D matrix rank procedure

* Add new tests for 1D matrices and jit

* Do not check dtypes to circumvent int32 vs int64

* Include documentation for matrix_rank

* Fix ordering

* Use np.sum
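
A simplified sketch of the SVD-based rank computation for 2-D inputs, using NumPy's default tolerance convention (illustrative, not necessarily the exact code added here):

    import jax.numpy as jnp

    def matrix_rank_sketch(a):
      s = jnp.linalg.svd(a, compute_uv=False)
      # Default tolerance: largest singular value times the larger matrix
      # dimension times machine epsilon for the dtype.
      tol = s.max() * max(a.shape[-2:]) * jnp.finfo(s.dtype).eps
      return jnp.sum(s > tol)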
2020-01-26 11:29:33 -08:00
Ziyad Edher
0c95c26e97 Implement np.linalg.matrix_power (#2042)
* Implement numpy.linalg.matrix_power

* Write tests for numpy.linalg.matrix_power

* Check for input matrix shapes

* Move to matrix-multiplication operator in matrix power

* Improve error messages and directly use broadcasting

* Include matrix_power in documentation
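
The usual strategy is exponentiation by squaring with the matrix-multiplication operator; a short sketch under that assumption (function name is illustrative):

    import jax.numpy as jnp

    def matrix_power_sketch(a, n):
      # Repeated squaring: O(log n) matrix multiplications for integer n >= 0.
      result = jnp.eye(a.shape[-1], dtype=a.dtype)
      while n > 0:
        if n % 2:
          result = result @ a
        a = a @ a
        n //= 2
      return result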
2020-01-24 13:52:40 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Tuan Nguyen
7f4b641c6d Additional doc for np.linalg.pinv (#1820)
* starter code

* Update scipy_stats_test.py

* Update __init__.py

* Update scipy_stats_test.py

* starter code for pinv

* fix transpose, add more test cases & complex dtype

* update test to latest format

* update default rcond

* Update linalg.py

* bigger test size

* Update linalg.py

* Update linalg_test.py

* fix float issue

* Update linalg.py

* smaller test cases

* Update linalg_test.py

* try not forcing float

* explicit cast

* try a different casting

* try another casting

* Update doc for pinv

* Update linalg.py
2019-12-09 09:56:26 -08:00
Tuan Nguyen
2316a29ae9 Implement np.linalg.pinv (#1656)
* starter code

* Update scipy_stats_test.py

* Update __init__.py

* Update scipy_stats_test.py

* starter code for pinv

* fix transpose, add more test cases & complex dtype

* update test to latest format

* update default rcond

* Update linalg.py

* bigger test size

* Update linalg.py

* Update linalg_test.py

* fix float issue

* Update linalg.py

* smaller test cases

* Update linalg_test.py

* try not forcing float

* explicit cast

* try a different casting

* try another casting
2019-12-03 11:15:39 -08:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of an np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular by preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
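
Under the weak-type rules described above, the example behaves roughly as follows (illustrative):

    import jax.numpy as jnp

    x = jnp.zeros(3, jnp.float32)
    # The Python scalar is weakly typed, so the array's dtype wins.
    (1.0 + x).dtype  # float32, rather than promoting to a 64-bit default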
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions directly, in favor of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Stephen Tu
39daf07de3 Add trivial implementations of eigvals/eigvalsh (#1604)
* Add trivial implementations of eigvals/eigvalsh

The implementations simply delegate to eig/eigh.

* Enable eigvalsh test on TPU/GPU
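
The delegation amounts to something like the following (sketch, not the exact code):

    import jax.numpy as jnp

    def eigvalsh_sketch(a):
      # Compute the symmetric/Hermitian eigendecomposition and keep only
      # the eigenvalues.
      w, _ = jnp.linalg.eigh(a)
      return w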
2019-10-30 22:29:56 -04:00
David Pfau
7a347dede9
Added TODO to fix slogdet grad for complex types 2019-09-17 18:58:34 +01:00
David Pfau
3e5f7869ea
Fixed prefix for np.linalg.solve 2019-09-16 21:27:55 +01:00
David Pfau
36aaaf9f55
Merge branch 'master' into master 2019-09-16 20:54:04 +01:00
David Pfau
d30108d773
Added jit to slogdet
This resolves the prior merge conflict
2019-09-16 20:52:36 +01:00
David Pfau
0171eb7c68
Replace jvp for det with jvp for slogdet 2019-09-14 14:35:27 +01:00
David Pfau
c2750c1be9
Extended jvp for det to handle inputs with >2 dims 2019-09-14 14:30:45 +01:00
Peter Hawkins
a5f67d553d Fix incorrect slogdet parity calculation in presence of batch dimensions. 2019-09-11 08:19:26 -04:00
David Pfau
b819bff1d3
Imports for custom_transforms and defjvp
Can't really do this if we don't have the right imports...
2019-09-05 15:22:36 +01:00
David Pfau
1c264db3a3
Slightly cleaner implementation of jvp for det
Replaced np.dot/np.linalg.inv with a single np.linalg.solve
2019-08-31 03:18:24 +01:00
David Pfau
ae2c39d081
Added custom jvp for np.linalg.det
For faster gradient computation for determinants, added a closed-form expression for the Jacobian-vector product. Still needs a test.
2019-08-31 02:13:12 +01:00
Peter Hawkins
5ac356d680 Add support for batched triangular solve and LU decomposition on GPU using cuBLAS. 2019-08-08 13:34:53 -04:00
Peter Hawkins
5336722ad7 Support rank > 2 inputs to np.linalg.norm if axis==None and ord==None.
Matches an undocumented NumPy behavior:
https://github.com/numpy/numpy/issues/14215
2019-08-07 09:21:07 -04:00
Peter Hawkins
efdcef88c7 Remove experimental warning from linalg routines.
There's no particular reason to scare people with the experimental warning any longer; we don't know of any bugs here.
2019-07-23 21:13:14 -04:00
Peter Hawkins
12e622bbc1 Implement np.linalg.inv in terms of np.linalg.solve.
i.e. use an LU decomposition instead of a QR decomposition, now that an LU decomposition is available on all platforms.
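
Conceptually, the change amounts to the following (sketch):

    import jax.numpy as jnp

    def inv_sketch(a):
      # Solve A X = I with the LU-based solve; the solution X is A^{-1}.
      return jnp.linalg.solve(a, jnp.eye(a.shape[-1], dtype=a.dtype))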
2019-06-28 15:49:38 -04:00