* Add options to compute left/right eigenvectors in geev.
The new arguments default to True, preserving backwards compatibility
between jaxlib and jax.
Reformulate eig-related operations in terms of the new geev API.
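As a rough usage sketch of the resulting eig wrapper (the module path and keyword names below are assumptions based on the description above, not a verbatim copy of the signature; the wrapper is CPU-only):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[0., 1.],
               [-2., -3.]])

# Assumed keyword names; both default to True for backwards compatibility.
w, vl, vr = lax.linalg.eig(a)
w_only, = lax.linalg.eig(a, compute_left_eigenvectors=False,
                         compute_right_eigenvectors=False)
```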
* Addressed hawkinsp's comments from google/jax#3882.
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
On TPU, the permutation is computed more efficiently during the decomposition itself, and the only use case that would not require computing it is evaluating determinants.
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than staging out
based only on data dependence. One improvement is that jitted functions can
consume less memory and cause less memory fragmentation, because large
constants are no longer instantiated at trace time. It also simplifies several
internals.
See https://github.com/google/jax/pull/3370 for more information.
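As a toy illustration (mine, not from the PR) of the kind of trace-time constant this avoids:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # With omnistaging enabled, this large constant is staged into the jitted
  # computation rather than being materialized at trace time.
  ones = jnp.ones((4096, 4096))
  return x + ones.sum()

f(1.0)
```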
* Add vjp and jvp rules for jnp.linalg.det
* Add tests for new determinant gradients
* Replace index_update with concatenate in cofactor_solve
This avoids the issue of index_update not having a transpose rule, removing one obstacle to automatically converting the JVP into a VJP (the np.where still needs to be dealt with).
* Change cofactor_solve so that it can be transposed
This allows a single JVP rule to provide both forward- and reverse-mode derivatives (see the sketch after this list).
* Update det grad tests
All tests pass now - however second derivatives still do not work for nonsingular matrices.
* Add explanation to docstring for _cofactor_solve
* Fixed comment
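For orientation, a minimal sketch (my own, covering only the invertible case) of the determinant JVP these commits implement, based on the classical identity d det(A) = tr(adj(A) dA). The real implementation routes this through the transposable _cofactor_solve so the same rule also yields the VJP and degrades gracefully for rank-deficient inputs:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def det(a):
  return jnp.linalg.det(a)

@det.defjvp
def det_jvp(primals, tangents):
  # For invertible A: d det(A) = det(A) * tr(A^{-1} dA).
  a, = primals
  da, = tangents
  d = jnp.linalg.det(a)
  return d, d * jnp.trace(jnp.linalg.solve(a, da), axis1=-2, axis2=-1)
```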
* Add explicit derivative for jax.numpy.linalg.pinv.
* Fix type confusion problems in the JVP rule for SVD, which caused it to produce 64-bit tangents for 32-bit primals.
Fixes GH1747
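For reference, a sketch of one standard explicit pinv derivative (the differential of the Moore-Penrose pseudoinverse at locally constant rank). This is my own illustration for a single real 2-D matrix, not the library's rule, which is more general:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def pinv(a):
  return jnp.linalg.pinv(a)

@pinv.defjvp
def pinv_jvp(primals, tangents):
  # dA+ = -A+ dA A+ + A+ A+^T dA^T (I - A A+) + (I - A+ A) dA^T A+^T A+
  a, = primals
  da, = tangents
  p = jnp.linalg.pinv(a)
  m, n = a.shape
  dp = (-p @ da @ p
        + p @ p.T @ da.T @ (jnp.eye(m) - a @ p)
        + (jnp.eye(n) - p @ a) @ da.T @ p.T @ p)
  return p, dp
```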
The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
gradient implementations for `solve`, which rely on differentiating the matrix
factorization.
In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.
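For orientation, a minimal sketch (my own illustration, not the code in this change) of wrapping a Cholesky-based solve in `lax.custom_linear_solve` for a symmetric positive definite `a` and a single right-hand-side vector:

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import cho_factor, cho_solve

def posdef_solve(a, b):
  # Factor once, under stop_gradient: the factorization itself is never
  # differentiated, because custom_linear_solve defines the solve's
  # derivatives implicitly in terms of matvec and (transpose) solves.
  factors = cho_factor(lax.stop_gradient(a))
  return lax.custom_linear_solve(
      lambda x: a @ x,                          # the linear map x -> A x
      b,                                        # single right-hand-side vector
      solve=lambda _, x: cho_solve(factors, x),
      symmetric=True)                           # transpose solve == solve
```

Differentiating `posdef_solve(a, b).sum()` then costs only matrix-vector products plus reuse of the existing triangular factors, per the reasoning above.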
In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:
from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax
def loss(solve):
    def f(a, b):
        return solve(a, b).sum()
    return f
rs = onp.random.RandomState(0)
N = 500
K = 1
a = rs.randn(N, N)
a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
b = jax.device_put(rs.randn(N, K))
# general matrix solve
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 11.4 ms -> 3.63 ms
# positive definite solve
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 9.22 ms -> 2.83 ms
* Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve
This single implementation supports NumPy-style broadcasting in both the NumPy
and SciPy interfaces to LU solve, even though upstream only NumPy (not SciPy)
supports broadcasting.
This change is technically backwards incompatible in the SciPy wrapper, which
previously supported adding extra dimensions to the end of `b`, e.g.,
`b.shape == (8, 4, 2)` with `a.shape == (8, 8)`. There was a test case for this
behavior, but it isn't documented in either JAX or SciPy.
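A small usage sketch of the broadcasting described above (shapes are illustrative, and I'm assuming the `jax.scipy.linalg` wrappers route through the consolidated `lax_linalg.lu_solve`):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg
import numpy as onp

rs = onp.random.RandomState(0)
a = jnp.asarray(rs.randn(3, 8, 8))     # batch of matrices
b = jnp.asarray(rs.randn(3, 8, 2))     # matching batch of right-hand sides

lu, piv = jsp_linalg.lu_factor(a)
x = jsp_linalg.lu_solve((lu, piv), b)  # broadcasts like np.linalg.solve
# The old trailing-dimension form (a.shape == (8, 8) with b.shape == (8, 4, 2))
# is no longer accepted by the SciPy wrapper.
```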
* fix recursive import
* Use np.vectorize instead of experimental.vectorize
* add np.linalg.cond in a third_party module (see the sketch after this list)
* remove unnecessary type maps
* rename cond.py to linalg.py for consistency
* shift LICENSE to correct directory; formatting changes; completed testing
* Add implementation and testing for tensorinv
* fix tests for tensorinv to stop build stalling
* blank __init__.py; add extra testing for cond; remove index assignments
* jax.lax.cond is breaking on jax.numpy.linalg.norm
* fix control flow issues; update tests to use curried functions
* clean up imports; remove commented code
* remove control flow from tests; remove unneeded functions
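For reference, a sketch (my own, not the third_party implementation) of the 2-norm condition number that np.linalg.cond computes by default, namely the ratio of the extreme singular values:

```python
import jax.numpy as jnp

def cond_2norm(a):
  # Singular values come back in descending order.
  s = jnp.linalg.svd(a, compute_uv=False)
  return s[..., 0] / s[..., -1]

cond_2norm(jnp.array([[1.0, 0.0],
                      [0.0, 1e-3]]))   # ~1000.0
```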
* Implement np.linalg.matrix_rank (see the sketch after this list)
* Test np.linalg.matrix_rank
* Use helper numpy testing function
* Fix issue with 1D matrix rank procedure
* Add new tests for 1D matrices and jit
* Do not check dtypes to circumvent int32 vs int64
* Include documentation for matrix_rank
* Fix ordering
* Use np.sum
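A sketch of the usual NumPy recipe for matrix rank (count singular values above a tolerance); this is illustrative, not a verbatim copy of the implementation:

```python
import jax.numpy as jnp

def matrix_rank(m, tol=None):
  # 1-D input: rank is 1 unless every entry is zero (mirrors NumPy).
  if m.ndim == 1:
    return jnp.any(m != 0).astype(jnp.int32)
  s = jnp.linalg.svd(m, compute_uv=False)
  if tol is None:
    # NumPy-style default tolerance: largest singular value times
    # max(rows, cols) times the machine epsilon of the dtype.
    tol = s.max() * max(m.shape[-2:]) * jnp.finfo(s.dtype).eps
  return jnp.sum(s > tol)
```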
* Implement numpy.linalg.matrix_power (see the sketch after this list)
* Write tests for numpy.linalg.matrix_power
* Check for input matrix shapes
* Move to matrix-multiplication operator in matrix power
* Improve error messages and directly use broadcasting
* Include matrix_power in documentation
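A sketch of matrix power via exponentiation by squaring with the @ operator, as described above (illustrative; it handles only non-negative integer exponents and omits the error messages mentioned in these commits):

```python
import jax.numpy as jnp

def matrix_power(a, n):
  # Square-and-multiply: O(log n) matrix products.
  if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError("matrix_power expects square matrices")
  result = jnp.broadcast_to(jnp.eye(a.shape[-1], dtype=a.dtype), a.shape)
  while n > 0:
    if n % 2:
      result = result @ a
    a = a @ a
    n //= 2
  return result
```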
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and promoting np.float64 with np.float32 gives np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
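A small illustration of the intended promotion behavior once weak types are in place:

```python
import jax.numpy as jnp

x = jnp.zeros(3, jnp.float32)
# The Python scalar is weakly typed, so the array's dtype wins:
print((1.0 + x).dtype)   # float32
```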
* Move internal type-related functions into a new (internal) jax.types module.
Avoid calling onp type functions directly; use the wrappers in jax.types instead. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
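For reference, the canonicalization helper mentioned above is exposed as `jax.dtypes.canonicalize_dtype`; the output shown assumes the default 32-bit mode:

```python
import numpy as onp
from jax import dtypes

# With 64-bit mode disabled (the default), dtypes canonicalize to 32 bits.
print(dtypes.canonicalize_dtype(onp.float64))   # float32
```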