# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast

from jax import jit, vmap, custom_jvp
from .. import lax
from .. import ops
from .. import lax_linalg
from .. import dtypes
from .lax_numpy import _not_implemented
from ._util import _wraps
from .vectorize import vectorize
from . import lax_numpy as jnp
from ..util import get_module_functions, canonicalize_axis
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve  # noqa: F401

_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conj(jnp.swapaxes(x, -1, -2))


def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  def _to_inexact_type(type):
    return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
  inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
  dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
  args = [lax.convert_element_type(arg, dtype) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args


@_wraps(np.linalg.cholesky)
def cholesky(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.cholesky(a)


@_wraps(np.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(a, full_matrices, compute_uv)


@_wraps(np.linalg.matrix_power)
def matrix_power(a, n):
  a = _promote_arg_dtypes(jnp.asarray(a))

  if a.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError:
    raise TypeError("exponent must be an integer, got {}".format(n))

  if n == 0:
    return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
  elif n < 0:
    a = inv(a)
    n = np.abs(n)

  if n == 1:
    return a
  elif n == 2:
    return a @ a
  elif n == 3:
    return (a @ a) @ a

  z = result = None
  while n > 0:
    z = a if z is None else (z @ z)
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)

  return result
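
# The loop above computes large powers by repeated squaring, so matrix_power(a, k)
# needs only O(log k) matrix multiplications; negative exponents go through inv(a)
# first. A small sanity check (sketch, assuming a square jnp array `a`):
#   matrix_power(a, 5)  # ~= a @ a @ a @ a @ a, up to floating-point error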


@_wraps(np.linalg.matrix_rank)
def matrix_rank(M, tol=None):
  M = _promote_arg_dtypes(jnp.asarray(M))
  if M.ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if M.ndim < 2:
    return jnp.any(M != 0).astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
  return jnp.sum(S > tol)


@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  lu, pivot = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
  if jnp.iscomplexobj(a):
    sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
  return sign, jnp.real(logdet)


@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
  x, = primals
  g, = tangents
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    raise NotImplementedError  # TODO(pfau): make this work for complex types
  sign, ans = slogdet(x)
  sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  return (sign, ans), (sign_dot, ans_dot)
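
# The JVP above uses d(log|det(x)|) = trace(solve(x, dx)); the sign output gets a
# zero tangent. Recovering the determinant itself (sketch, assuming a nonsingular
# square jnp array `a`):
#   sign, logabsdet = slogdet(a)
#   det_a = sign * jnp.exp(logabsdet)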


def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then
    y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
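
# A sketch of the identity this helper implements (assuming a nonsingular square
# jnp array `a` and a matching `b`):
#   det_a, adj_b = _cofactor_solve(a, b)
#   # det_a ~= det(a)
#   # adj_b ~= adjugate(a) @ b == det(a) * solve(a, b)
# Unlike evaluating det(a) * solve(a, b) directly, it stays finite when a has
# rank n-1, which is what the det JVP below relies on.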


@custom_jvp
@_wraps(np.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * jnp.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, jnp.trace(z, axis1=-1, axis2=-2)
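
# The JVP above is d(det(x)) = trace(adjugate(x) @ dx). For a nonsingular 2D `a`
# this is equivalent to the familiar gradient det(a) * inv(a).T, e.g. (sketch):
#   jax.grad(det)(a)  # ~= det(a) * inv(a).T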


@_wraps(np.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  w, vl, vr = lax_linalg.eig(a)
  return w, vr


@_wraps(np.linalg.eigvals)
def eigvals(a):
  w, _ = eig(a)
  return w


@_wraps(np.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
  if UPLO is None or UPLO == "L":
    lower = True
  elif UPLO == "U":
    lower = False
  else:
    msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
    raise ValueError(msg)

  a = _promote_arg_dtypes(jnp.asarray(a))
  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
  return w, v


@_wraps(np.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
  w, _ = eigh(a, UPLO)
  return w


@partial(custom_jvp, nondiff_argnums=(1,))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
    """))
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, v = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)


@pinv.defjvp
def _pinv_jvp(rcond, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals
  a_dot, = tangents
  p = pinv(a, rcond=rcond)
  m, n = a.shape[-2:]
  # TODO(phawkins): on TPU, we would need to opt into high precision here.
  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
  p_dot = -p @ a_dot @ p
  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
  p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
  return p, p_dot
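
# The three terms above are the Golub & Pereyra expression for d(pinv(a)):
#   -p da p  +  p p^H da^H (I - a p)  +  (I - p a) da^H p^H p
# A quick sanity check of pinv itself (sketch, assuming a full-rank jnp array `a`):
#   p = pinv(a)
#   # a @ p @ a ~= a and p @ a @ p ~= p  (Moore-Penrose conditions)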


@_wraps(np.linalg.inv)
def inv(a):
  if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
                     .format(jnp.shape(a)))
  return solve(
    a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))


@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))


@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
  return _norm(x, ord, axis, keepdims)


@_wraps(np.linalg.qr)
def qr(a, mode="reduced"):
  if mode in ("reduced", "r", "full"):
    full_matrices = False
  elif mode == "complete":
    full_matrices = True
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  a = _promote_arg_dtypes(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  if mode == "r":
    return r
  return q, r


def _check_solve_shapes(a, b):
  if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a.shape, b.shape))


@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
  return jnp.dot(a, b, precision=lax.Precision.HIGHEST)


@_wraps(np.linalg.solve)
@jit
def solve(a, b):
  a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
  _check_solve_shapes(a, b)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu, pivots = lax_linalg.lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=0),
      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
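
# Because the solve is expressed through lax.custom_linear_solve, forward- and
# reverse-mode derivatives reuse the LU factors via lu_solve instead of
# differentiating the factorization itself. A usage sketch (assuming a square,
# well-conditioned `a` and a `b` with matching leading dimension):
#   x = solve(a, b)                                    # solves a @ x = b
#   g = jax.grad(lambda a, b: solve(a, b).sum())(a, b)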


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
       the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or underdetermined
       solutions. Here, the residuals are returned in all cases, to make the function
       compatible with jit. The non-jit compatible numpy behavior can be recovered by
       passing numpy_resid=True.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  elif rcond < 0:
    rcond = jnp.finfo(dtype).eps
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
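
# A usage sketch (assuming a 2D jnp array `a` of shape (m, n) and `b` of shape
# (m,) or (m, k)); the solution minimizes ||a @ x - b||:
#   x, resid, rank, s = lstsq(a, b)
#   # pass numpy_resid=True to reproduce NumPy's empty-residual behavior,
#   # at the cost of jit compatibility.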


_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.linalg).items():
  if name not in globals():
    _NOT_IMPLEMENTED.append(name)
    globals()[name] = _not_implemented(func)